Commit 1ba8a08f authored by carlushuang's avatar carlushuang
Browse files

update tmp work

parent bf214665
...@@ -252,7 +252,7 @@ struct FusedMoeKernel ...@@ -252,7 +252,7 @@ struct FusedMoeKernel
CK_TILE_DEVICE void operator()(Kargs kargs) const CK_TILE_DEVICE void operator()(Kargs kargs) const
{ {
// allocate LDS // allocate LDS
__shared__ char smem_ptr[GetSmemSize()]; // __shared__ char smem_ptr[GetSmemSize()];
ck_tile::index_t num_sorted_tiles = __builtin_amdgcn_readfirstlane( ck_tile::index_t num_sorted_tiles = __builtin_amdgcn_readfirstlane(
*reinterpret_cast<const ck_tile::index_t*>(kargs.num_sorted_tiles_ptr)); *reinterpret_cast<const ck_tile::index_t*>(kargs.num_sorted_tiles_ptr));
ck_tile::index_t tile_id = __builtin_amdgcn_readfirstlane(blockIdx.x;); ck_tile::index_t tile_id = __builtin_amdgcn_readfirstlane(blockIdx.x;);
...@@ -436,8 +436,7 @@ struct FusedMoeKernel ...@@ -436,8 +436,7 @@ struct FusedMoeKernel
u_gtile_window, u_gtile_window,
d_gtile_window, d_gtile_window,
o_gtile_window, o_gtile_window,
scale, scale);
smem_ptr);
tile_id += gridDim.x; tile_id += gridDim.x;
} }
......
...@@ -117,7 +117,7 @@ struct BlockFmhaPipelineQRKSVSAsync ...@@ -117,7 +117,7 @@ struct BlockFmhaPipelineQRKSVSAsync
// this is the thread-offset along row/col // this is the thread-offset along row/col
CK_TILE_HOST_DEVICE static auto GetAIndex() CK_TILE_HOST_DEVICE static auto GetAIndex()
{ {
constexpr auto a_dist = Policy::template MakeAGlobalTileDistribution<Problem>(); constexpr auto a_dist = Policy::template MakeGlobalTileDistribution_A<Problem>();
const auto a_coord = a_dist.calculate_index(); const auto a_coord = a_dist.calculate_index();
return a_coord; return a_coord;
} }
...@@ -142,7 +142,8 @@ struct BlockFmhaPipelineQRKSVSAsync ...@@ -142,7 +142,8 @@ struct BlockFmhaPipelineQRKSVSAsync
OGlobalTensorView& o_gtile_window_tmp, OGlobalTensorView& o_gtile_window_tmp,
// const void * sorted_weight_ptr, // const void * sorted_weight_ptr,
ScaleDataType scale, ScaleDataType scale,
void* smem_ptr, CK_TILE_LDS_ADDR void* smem_0,
CK_TILE_LDS_ADDR void* smem_1,
index_t dim_size, index_t dim_size,
index_t hidden_size) index_t hidden_size)
{ {
...@@ -153,25 +154,25 @@ struct BlockFmhaPipelineQRKSVSAsync ...@@ -153,25 +154,25 @@ struct BlockFmhaPipelineQRKSVSAsync
make_tile_window(a_gtile_window_tmp.get_bottom_tensor_view(), make_tile_window(a_gtile_window_tmp.get_bottom_tensor_view(),
a_gtile_window_tmp.get_window_lengths(), a_gtile_window_tmp.get_window_lengths(),
a_gtile_window_tmp.get_window_origin(), a_gtile_window_tmp.get_window_origin(),
Policy::template MakeAGlobalTileDistribution<Problem>()); Policy::template MakeGlobalTileDistribution_A<Problem>());
auto g_gtile_window = auto g_gtile_window =
make_tile_window(g_gtile_window_tmp.get_bottom_tensor_view(), make_tile_window(g_gtile_window_tmp.get_bottom_tensor_view(),
g_gtile_window_tmp.get_window_lengths(), g_gtile_window_tmp.get_window_lengths(),
g_gtile_window_tmp.get_window_origin(), g_gtile_window_tmp.get_window_origin(),
Policy::template MakeGGlobalTileDistribution<Problem>()); Policy::template MakeGlobalTileDistribution_G<Problem>());
auto u_gtile_window = auto u_gtile_window =
make_tile_window(u_gtile_window_tmp.get_bottom_tensor_view(), make_tile_window(u_gtile_window_tmp.get_bottom_tensor_view(),
u_gtile_window_tmp.get_window_lengths(), u_gtile_window_tmp.get_window_lengths(),
u_gtile_window_tmp.get_window_origin(), u_gtile_window_tmp.get_window_origin(),
Policy::template MakeUGlobalTileDistribution<Problem>()); Policy::template MakeGlobalTileDistribution_U<Problem>());
auto d_gtile_window = auto d_gtile_window =
make_tile_window(d_gtile_window_tmp.get_bottom_tensor_view(), make_tile_window(d_gtile_window_tmp.get_bottom_tensor_view(),
d_gtile_window_tmp.get_window_lengths(), d_gtile_window_tmp.get_window_lengths(),
d_gtile_window_tmp.get_window_origin(), d_gtile_window_tmp.get_window_origin(),
Policy::template MakeDGlobalTileDistribution<Problem>()); Policy::template MakeGlobalTileDistribution_D<Problem>());
auto o_gtile_window = auto o_gtile_window =
make_tile_window(o_gtile_window_tmp.get_bottom_tensor_view(), make_tile_window(o_gtile_window_tmp.get_bottom_tensor_view(),
...@@ -187,12 +188,13 @@ struct BlockFmhaPipelineQRKSVSAsync ...@@ -187,12 +188,13 @@ struct BlockFmhaPipelineQRKSVSAsync
auto a_smem_ptr = reinterpret_cast<ADataType*>(smem_ptr) + a_smem_offset; auto a_smem_ptr = reinterpret_cast<ADataType*>(smem_ptr) + a_smem_offset;
make_tile_window(make_tensor_view<address_space_enum::lds>( auto smem_0_window = make_tile_window(
a_smem_ptr, Policy::template MakeALdsStoreBlockDescriptor<Problem>()), make_tensor_view<address_space_enum::lds>(
Policy::template MakeALdsStoreBlockDescriptor<Problem>().get_lengths(), smem_0, Policy::template MakeLdsStoreBlockDescriptor_A<Problem>()),
{0, 0}); Policy::template MakeLdsStoreBlockDescriptor_A<Problem>().get_lengths(),
{0, 0});
async_load_tile_raw(k_lds_store(LdsSeq.at(number<0>{})), a_gtile_window); async_load_tile(k_lds_store(LdsSeq.at(number<0>{})));
for(index_t i_0 = 0; i_0 < loops_0; i_0++) {} for(index_t i_0 = 0; i_0 < loops_0; i_0++) {}
} }
...@@ -351,8 +353,8 @@ struct BlockFmhaPipelineQRKSVSAsync ...@@ -351,8 +353,8 @@ struct BlockFmhaPipelineQRKSVSAsync
store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse)); store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse));
} }
buffer_load_fence(0); // rocm-6.1, if whole tile is masked out, need to fence(0) buffer_load_fence_raw(0); // rocm-6.1, if whole tile is masked out, need to fence(0)
// otherwise will have compute error(maybe compiler bug?) // otherwise will have compute error(maybe compiler bug?)
// Note: here occ are all cleard, return it // Note: here occ are all cleard, return it
return o_acc; return o_acc;
...@@ -403,7 +405,7 @@ struct BlockFmhaPipelineQRKSVSAsync ...@@ -403,7 +405,7 @@ struct BlockFmhaPipelineQRKSVSAsync
move_tile_window(k_dram_window, {0, kK0}); move_tile_window(k_dram_window, {0, kK0});
__builtin_amdgcn_sched_barrier(0); __builtin_amdgcn_sched_barrier(0);
buffer_load_fence(k_dram_window.get_num_access(), q.get_thread_buffer()); buffer_load_fence_raw(k_dram_window.get_num_access(), q.get_thread_buffer());
(void)q_element_func; // ??? rocm-6.x if use q element func will have scratch on hdim=64/32 (void)q_element_func; // ??? rocm-6.x if use q element func will have scratch on hdim=64/32
// auto q_tile = q; // tile_elementwise_in(q_element_func, q); // auto q_tile = q; // tile_elementwise_in(q_element_func, q);
...@@ -428,7 +430,7 @@ struct BlockFmhaPipelineQRKSVSAsync ...@@ -428,7 +430,7 @@ struct BlockFmhaPipelineQRKSVSAsync
if constexpr(i_k0 < k0_loops - 1) if constexpr(i_k0 < k0_loops - 1)
move_tile_window(k_dram_window, {0, kK0}); move_tile_window(k_dram_window, {0, kK0});
async_load_fence(k_dram_window.get_num_access()); async_load_fence_raw(k_dram_window.get_num_access());
__builtin_amdgcn_s_barrier(); __builtin_amdgcn_s_barrier();
__builtin_amdgcn_sched_barrier(0); __builtin_amdgcn_sched_barrier(0);
gemm_0(s_acc, gemm_0(s_acc,
...@@ -450,7 +452,7 @@ struct BlockFmhaPipelineQRKSVSAsync ...@@ -450,7 +452,7 @@ struct BlockFmhaPipelineQRKSVSAsync
if constexpr(k0_loops <= 2) if constexpr(k0_loops <= 2)
__builtin_amdgcn_sched_barrier(0); __builtin_amdgcn_sched_barrier(0);
async_load_fence(); async_load_fence_raw();
__builtin_amdgcn_s_barrier(); __builtin_amdgcn_s_barrier();
const auto bias_tile = load_tile(bias_dram_window); // load bias tile const auto bias_tile = load_tile(bias_dram_window); // load bias tile
......
...@@ -12,21 +12,22 @@ ...@@ -12,21 +12,22 @@
namespace ck_tile { namespace ck_tile {
template <bool GateUpPreShuffled_ = false, enum class FusedMoePermuteStyle
bool DownPreShuffled_ = false, {
index_t NumPrefetchA_ = 2, // permute_b_n0_k0_n1_k1_n2_k2 = 0, // 0,1,4,2,5,3,6
index_t NumPrefetchG_ = 2, // permute_b_n0_n1_k0_k1_n2_k2 = 1, // 0,1,2,4,5,3,6
index_t NumPrefetchU_ = 2, permute_b_nr_kr_kw_nw_kv = 2, // 0,1,3,4,2,5
index_t NumPrefetchD_ = 2, permute_b_nr_kr_waveflatten = permute_b_nr_kr_kw_nw_kv,
index_t kBlockPerCu_ = -1 /* overwrite occupancy if not -1 */> no_permute = 999,
};
template <bool DownPreShuffled_ = false,
FusedMoePermuteStyle PermuteStyle_ = FusedMoePermuteStyle::permute_b_nr_kr_kw_nw_kv,
index_t kBlockPerCu_ = -1 /* overwrite occupancy if not -1 */>
struct FusedMoeTraits struct FusedMoeTraits
{ {
static constexpr bool GateUpPreShuffled = GateUpPreShuffled_; static constexpr bool DownPreShuffled = DownPreShuffled_;
static constexpr bool DownPreShuffled = DownPreShuffled_; static constexpr FusedMoePermuteStyle PermuteStyle = PermuteStyle_;
static constexpr index_t NumPrefetchA = NumPrefetchA_; static constexpr index_t kBlockPerCu = kBlockPerCu_;
static constexpr index_t NumPrefetchG = NumPrefetchG_;
static constexpr index_t NumPrefetchU = NumPrefetchU_;
static constexpr index_t NumPrefetchD = NumPrefetchD_;
static constexpr index_t kBlockPerCu = kBlockPerCu_;
}; };
} // namespace ck_tile } // namespace ck_tile
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment