Unverified commit cf7be057, authored by Lei Wang and committed by GitHub

[Bugfix] Add missing FP8 header include (#752)



* [Enhancement] Add DispatchInstruction specialization for fp8 types in gemm_sm90.h

- Introduced specialized DispatchInstruction templates for fp8_e4_t and fp8_e5_t types, enhancing support for new data formats in CUDA GEMM operations.
- Each specialization defines the corresponding MMA and MMA_Group types for its tile configuration (a hedged sketch of the pattern follows this message).
Co-authored-by: LeiWang1999 <leiwang1999@outlook.com>
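
A hedged sketch of the pattern described above (the template parameter list, the GMMA atom, and the group layout are illustrative assumptions, not the exact code in gemm_sm90.h):

// Illustrative only: parameter list, atom choice, and layout are assumptions.
#include "cuda_fp8.h"             // fp8_e4_t / fp8_e5_t (the missing include)
#include <cute/arch/mma_sm90.hpp> // SM90 GMMA ops
#include <cute/atom/mma_atom.hpp> // MMA_Atom, Layout, Shape

namespace cute {

template <typename A_type, typename B_type, typename C_type, int num_warp_m,
          int num_warp_n>
struct DispatchInstruction; // primary template left undefined

// fp8 e4m3 inputs accumulating into fp32. The atom named below is one plausible
// CuTe SM90 GMMA op for this combination (an assumption, not taken from the PR).
template <int num_warp_m, int num_warp_n>
struct DispatchInstruction<fp8_e4_t, fp8_e4_t, float, num_warp_m, num_warp_n> {
  using MMA = MMA_Atom<SM90_64x64x32_F32E4M3E4M3_SS_TN<>>;
  using MMA_Group = Layout<Shape<Int<num_warp_m>, Int<num_warp_n>, _1>>;
};

// fp8_e5_t (e5m2) gets an analogous specialization with an E5M2 atom.

} // namespace cute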

* [Enhancement] Include cuda_fp8.h in gemm_sm90.h

- Added the inclusion of the "cuda_fp8.h" header so that the fp8 types used by the new GEMM specializations resolve correctly (a brief illustration follows this message).
Co-authored-by: LeiWang1999 <leiwang1999@outlook.com>
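
For context, a minimal illustration of why the include matters, assuming (purely for this sketch) that the project's fp8 aliases sit on top of CUDA's __nv_fp8 types:

// Assumption for illustration only -- the real aliases live in the project's
// cuda_fp8.h, which builds on the types declared by CUDA's <cuda_fp8.h>.
#include <cuda_fp8.h>
using fp8_e4_t = __nv_fp8_e4m3; // assumed alias
using fp8_e5_t = __nv_fp8_e5m2; // assumed alias
// Without this header being pulled in, the fp8 specializations added to
// gemm_sm90.h do not compile -- the missing include this PR adds.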

* lint fix

* [Refactor] Remove unused tl_shuffle_elect and related functions from common.h

- Deleted the `tl_shuffle_elect` function and its associated comments to streamline the codebase.
- Added inclusion of "intrin.h" for improved intrinsic support in CUDA operations.
- Cleaned up the file by removing unnecessary template parameters and functions, enhancing clarity and maintainability.

* lint fix

* [Refactor] Update header inclusions in common.h and gemm_sm90.h

- Removed the inclusion of "intrin.h" from common.h to streamline dependencies.
- Added "intrin.h" inclusion in gemm_sm90.h to ensure intrinsic support for CUDA operations, enhancing functionality and maintainability.

* bug fix
parent c2fe91e0
@@ -240,53 +240,4 @@ template <int barrier_id = 0, int thread_count = 0>
TL_DEVICE void __sync_thread_partial() {
  asm volatile("bar.sync %0, %1;" : : "r"(barrier_id), "r"(thread_count));
}
// Template parameter:
//   thread_extent: the logical size (in number of threads) of each "group"
//                  within which we want to elect exactly ONE representative
//                  thread.
template <int thread_extent> TL_DEVICE bool tl_shuffle_elect() {
  // Special case: thread_extent == 0 means "elect exactly one thread
  // in the entire thread block", i.e., the leader of the first warp of the
  // block.
  if constexpr (thread_extent == 0) {
    // cutlass::canonical_warp_idx_sync():
    //   Returns the warp ID within the thread block in a "canonical" way
    //   (0 for the first warp, 1 for the second, ...).
    // cute::elect_one_sync():
    //   Elects exactly one lane in the warp to return true (typically lane 0);
    //   the other lanes return false.
    // The condition ensures that:
    //   (1) We are in warp 0 of the block.
    //   (2) We are the elected lane in this warp.
    return cutlass::canonical_warp_idx_sync() == 0 && cute::elect_one_sync();
  }
  // General case: thread_extent != 0
  //   (threadIdx.x / 32) is the warp index in the block.
  //   (thread_extent / 32) is the number of warps in one group of size
  //   thread_extent. We take warp_id % num_warps_in_group to get the warp's
  //   index within the group.
  // __shfl_sync(mask, value, srcLane): broadcasts 'value' from srcLane to all
  // lanes in the warp. Here it broadcasts the group-local warp index from
  // lane 0. Comparing to 0 selects only the group's warp 0.
  return __shfl_sync(0xffffffff, // full warp mask
                     (threadIdx.x / 32) %
                         (thread_extent / 32), // warp index within group
                     0                         // take the value from lane 0
                     ) == 0 &&
         // Within that group leader warp, elect exactly one lane (typically
         // lane 0) to be the single representative for the group.
         cute::elect_one_sync();
}

template <uint32_t RegCount> TL_DEVICE void warpgroup_reg_alloc() {
  asm volatile("setmaxnreg.inc.sync.aligned.u32 %0;\n" : : "n"(RegCount));
}

template <uint32_t RegCount> TL_DEVICE void warpgroup_reg_dealloc() {
  asm volatile("setmaxnreg.dec.sync.aligned.u32 %0;\n" : : "n"(RegCount));
}
} // namespace tl

#pragma once
#include "common.h"
#include "cuda_fp8.h"
#include "intrin.h"
#include <cute/arch/mma_sm80.hpp>
#include <cute/arch/mma_sm90.hpp>
#include <cute/atom/mma_atom.hpp>
@@ -7,8 +10,6 @@
#include <cutlass/cutlass.h>
#include <cutlass/gemm/collective/collective_builder.hpp>
#include "common.h"
namespace cute {
using namespace SM90;
...
#pragma once
#if __CUDA_ARCH_LIST__ >= 900
#include "cute/arch/cluster_sm90.hpp"
#include "cutlass/cutlass.h"
namespace tl {
// Template parameter:
//   thread_extent: the logical size (in number of threads) of each "group"
//                  within which we want to elect exactly ONE representative
//                  thread.
template <int thread_extent> TL_DEVICE bool tl_shuffle_elect() {
  // Special case: thread_extent == 0 means "elect exactly one thread
  // in the entire thread block", i.e., the leader of the first warp of the
  // block.
  if constexpr (thread_extent == 0) {
    // cutlass::canonical_warp_idx_sync():
    //   Returns the warp ID within the thread block in a "canonical" way
    //   (0 for the first warp, 1 for the second, ...).
    // cute::elect_one_sync():
    //   Elects exactly one lane in the warp to return true (typically lane 0);
    //   the other lanes return false.
    // The condition ensures that:
    //   (1) We are in warp 0 of the block.
    //   (2) We are the elected lane in this warp.
    return cutlass::canonical_warp_idx_sync() == 0 && cute::elect_one_sync();
  }
  // General case: thread_extent != 0
  //   (threadIdx.x / 32) is the warp index in the block.
  //   (thread_extent / 32) is the number of warps in one group of size
  //   thread_extent. We take warp_id % num_warps_in_group to get the warp's
  //   index within the group.
  // __shfl_sync(mask, value, srcLane): broadcasts 'value' from srcLane to all
  // lanes in the warp. Here it broadcasts the group-local warp index from
  // lane 0. Comparing to 0 selects only the group's warp 0.
  return __shfl_sync(0xffffffff, // full warp mask
                     (threadIdx.x / 32) %
                         (thread_extent / 32), // warp index within group
                     0                         // take the value from lane 0
                     ) == 0 &&
         // Within that group leader warp, elect exactly one lane (typically
         // lane 0) to be the single representative for the group.
         cute::elect_one_sync();
}

template <uint32_t RegCount> TL_DEVICE void warpgroup_reg_alloc() {
  asm volatile("setmaxnreg.inc.sync.aligned.u32 %0;\n" : : "n"(RegCount));
}

template <uint32_t RegCount> TL_DEVICE void warpgroup_reg_dealloc() {
  asm volatile("setmaxnreg.dec.sync.aligned.u32 %0;\n" : : "n"(RegCount));
}
} // namespace tl
#endif
\ No newline at end of file
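
For reference, a hedged usage sketch of the intrinsics that now live in this header (the kernel shape, register counts, and the elected thread's work are illustrative assumptions, not code from this repository):

// Requires an sm_90a build and the CUTLASS/CuTe include paths; common.h is
// included first because it provides TL_DEVICE and the helpers intrin.h uses.
#include "common.h"
#include "intrin.h"

__global__ void producer_consumer_sketch() {
  // Treat each 128-thread warp group separately: group 0 produces, the rest consume.
  const int warp_group = threadIdx.x / 128;

  if (warp_group == 0) {
    // The producer warp group hands back registers it does not need for math.
    tl::warpgroup_reg_dealloc<24>();
    // Elect exactly one thread per 128-thread group to issue a one-time
    // operation (for example, starting an async copy or arriving on a barrier).
    if (tl::tl_shuffle_elect<128>()) {
      // ... the single elected thread's work goes here ...
    }
  } else {
    // Consumer warp groups raise their register budget for the MMA math.
    tl::warpgroup_reg_alloc<240>();
  }
}

tl_shuffle_elect<0>() is the block-wide variant: it returns true only for the elected lane of the block's first warp.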