Commit 1ba8a08f authored by carlushuang

update tmp work

parent bf214665
@@ -252,7 +252,7 @@ struct FusedMoeKernel
CK_TILE_DEVICE void operator()(Kargs kargs) const
{
// allocate LDS
__shared__ char smem_ptr[GetSmemSize()];
// __shared__ char smem_ptr[GetSmemSize()];
ck_tile::index_t num_sorted_tiles = __builtin_amdgcn_readfirstlane(
*reinterpret_cast<const ck_tile::index_t*>(kargs.num_sorted_tiles_ptr));
ck_tile::index_t tile_id = __builtin_amdgcn_readfirstlane(blockIdx.x);
@@ -436,8 +436,7 @@ struct FusedMoeKernel
u_gtile_window,
d_gtile_window,
o_gtile_window,
scale,
smem_ptr);
scale);
tile_id += gridDim.x;
}
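For context, the kernel body touched above uses a persistent-block schedule: each workgroup reads its starting tile id from `blockIdx.x` (made wave-uniform with `__builtin_amdgcn_readfirstlane`) and then strides by `gridDim.x` until all sorted tiles are consumed. A minimal sketch of that loop shape, with the per-tile work elided behind a hypothetical `process_tile()` helper:

```cpp
// Persistent-block tile loop as used by the kernel above (sketch).
// process_tile(...) is a hypothetical placeholder for the per-tile pipeline call.
ck_tile::index_t num_sorted_tiles = __builtin_amdgcn_readfirstlane(
    *reinterpret_cast<const ck_tile::index_t*>(kargs.num_sorted_tiles_ptr));

ck_tile::index_t tile_id = __builtin_amdgcn_readfirstlane(blockIdx.x);
while(tile_id < num_sorted_tiles)
{
    process_tile(kargs, tile_id); // hypothetical: build tile windows and run the pipeline
    tile_id += gridDim.x;         // stride to this workgroup's next tile
}
```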
......
@@ -117,7 +117,7 @@ struct BlockFmhaPipelineQRKSVSAsync
// this is the thread-offset along row/col
CK_TILE_HOST_DEVICE static auto GetAIndex()
{
constexpr auto a_dist = Policy::template MakeAGlobalTileDistribution<Problem>();
constexpr auto a_dist = Policy::template MakeGlobalTileDistribution_A<Problem>();
const auto a_coord = a_dist.calculate_index();
return a_coord;
}
@@ -142,7 +142,8 @@ struct BlockFmhaPipelineQRKSVSAsync
OGlobalTensorView& o_gtile_window_tmp,
// const void * sorted_weight_ptr,
ScaleDataType scale,
void* smem_ptr,
CK_TILE_LDS_ADDR void* smem_0,
CK_TILE_LDS_ADDR void* smem_1,
index_t dim_size,
index_t hidden_size)
{
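The signature change above replaces the single `void* smem_ptr` argument with two `CK_TILE_LDS_ADDR` pointers, while the kernel-side call in the first file currently passes no LDS argument at all. A minimal caller-side sketch, assuming the kernel keeps one shared-memory arena and splits it in two; the size constants, the window names `a_gtile_window`/`g_gtile_window`, and the `pipeline(...)` call shape are illustrative, not the committed interface:

```cpp
// Hypothetical caller: carve one __shared__ arena into the two LDS regions
// the pipeline now expects. kSmemSize0/kSmemSize1 are placeholder constants.
__shared__ char smem[kSmemSize0 + kSmemSize1];
char* smem_0 = smem;              // first LDS region (e.g. gate/up staging)
char* smem_1 = smem + kSmemSize0; // second LDS region (e.g. down staging)

pipeline(a_gtile_window,
         g_gtile_window,
         u_gtile_window,
         d_gtile_window,
         o_gtile_window,
         scale,
         smem_0,
         smem_1,
         dim_size,
         hidden_size);
```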
@@ -153,25 +154,25 @@ struct BlockFmhaPipelineQRKSVSAsync
make_tile_window(a_gtile_window_tmp.get_bottom_tensor_view(),
a_gtile_window_tmp.get_window_lengths(),
a_gtile_window_tmp.get_window_origin(),
Policy::template MakeAGlobalTileDistribution<Problem>());
Policy::template MakeGlobalTileDistribution_A<Problem>());
auto g_gtile_window =
make_tile_window(g_gtile_window_tmp.get_bottom_tensor_view(),
g_gtile_window_tmp.get_window_lengths(),
g_gtile_window_tmp.get_window_origin(),
Policy::template MakeGGlobalTileDistribution<Problem>());
Policy::template MakeGlobalTileDistribution_G<Problem>());
auto u_gtile_window =
make_tile_window(u_gtile_window_tmp.get_bottom_tensor_view(),
u_gtile_window_tmp.get_window_lengths(),
u_gtile_window_tmp.get_window_origin(),
Policy::template MakeUGlobalTileDistribution<Problem>());
Policy::template MakeGlobalTileDistribution_U<Problem>());
auto d_gtile_window =
make_tile_window(d_gtile_window_tmp.get_bottom_tensor_view(),
d_gtile_window_tmp.get_window_lengths(),
d_gtile_window_tmp.get_window_origin(),
Policy::template MakeDGlobalTileDistribution<Problem>());
Policy::template MakeGlobalTileDistribution_D<Problem>());
auto o_gtile_window =
make_tile_window(o_gtile_window_tmp.get_bottom_tensor_view(),
@@ -187,12 +188,13 @@ struct BlockFmhaPipelineQRKSVSAsync
auto a_smem_ptr = reinterpret_cast<ADataType*>(smem_ptr) + a_smem_offset;
make_tile_window(make_tensor_view<address_space_enum::lds>(
a_smem_ptr, Policy::template MakeALdsStoreBlockDescriptor<Problem>()),
Policy::template MakeALdsStoreBlockDescriptor<Problem>().get_lengths(),
auto smem_0_window = make_tile_window(
make_tensor_view<address_space_enum::lds>(
smem_0, Policy::template MakeLdsStoreBlockDescriptor_A<Problem>()),
Policy::template MakeLdsStoreBlockDescriptor_A<Problem>().get_lengths(),
{0, 0});
async_load_tile_raw(k_lds_store(LdsSeq.at(number<0>{})), a_gtile_window);
async_load_tile(k_lds_store(LdsSeq.at(number<0>{})));
for(index_t i_0 = 0; i_0 < loops_0; i_0++) {}
}
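The hunk above rebuilds the LDS-side store window directly from `smem_0` and the renamed `MakeLdsStoreBlockDescriptor_A`. Condensed, the pattern is: wrap the raw LDS pointer in an `lds` address-space tensor view using the block descriptor, open a tile window over it, then issue the asynchronous copy from the global window into that LDS window. A sketch of that pattern, assuming the `_raw` async-load variant that pairs the LDS destination with the global source window:

```cpp
// Build an LDS-backed tile window over smem_0 using the A store descriptor,
// then asynchronously stream the current global A tile into it (sketch).
auto a_lds_store_window = make_tile_window(
    make_tensor_view<address_space_enum::lds>(
        smem_0, Policy::template MakeLdsStoreBlockDescriptor_A<Problem>()),
    Policy::template MakeLdsStoreBlockDescriptor_A<Problem>().get_lengths(),
    {0, 0});

// assumption: the raw variant takes (LDS destination window, global source window)
async_load_tile_raw(a_lds_store_window, a_gtile_window);
```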
@@ -351,7 +353,7 @@ struct BlockFmhaPipelineQRKSVSAsync
store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse));
}
buffer_load_fence(0); // rocm-6.1, if whole tile is masked out, need to fence(0)
buffer_load_fence_raw(0); // rocm-6.1, if whole tile is masked out, need to fence(0)
// otherwise will have compute error(maybe compiler bug?)
// Note: here occ are all cleared, return it
@@ -403,7 +405,7 @@ struct BlockFmhaPipelineQRKSVSAsync
move_tile_window(k_dram_window, {0, kK0});
__builtin_amdgcn_sched_barrier(0);
buffer_load_fence(k_dram_window.get_num_access(), q.get_thread_buffer());
buffer_load_fence_raw(k_dram_window.get_num_access(), q.get_thread_buffer());
(void)q_element_func; // ??? rocm-6.x if use q element func will have scratch on hdim=64/32
// auto q_tile = q; // tile_elementwise_in(q_element_func, q);
@@ -428,7 +430,7 @@ struct BlockFmhaPipelineQRKSVSAsync
if constexpr(i_k0 < k0_loops - 1)
move_tile_window(k_dram_window, {0, kK0});
async_load_fence(k_dram_window.get_num_access());
async_load_fence_raw(k_dram_window.get_num_access());
__builtin_amdgcn_s_barrier();
__builtin_amdgcn_sched_barrier(0);
gemm_0(s_acc,
@@ -450,7 +452,7 @@ struct BlockFmhaPipelineQRKSVSAsync
if constexpr(k0_loops <= 2)
__builtin_amdgcn_sched_barrier(0);
async_load_fence();
async_load_fence_raw();
__builtin_amdgcn_s_barrier();
const auto bias_tile = load_tile(bias_dram_window); // load bias tile
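The `_raw` renames above all sit in the same synchronization shape used by the async-pipelined K loop: advance the prefetch window, fence down to the number of loads that may still be outstanding, then barrier before the MFMA consumes the tile now resident in LDS. A condensed sketch of one iteration, assuming the renamed fences keep the semantics of the calls they replace; the gemm operands are placeholders:

```cpp
// One iteration of the async-pipelined K loop (condensed sketch).
if constexpr(i_k0 < k0_loops - 1)
    move_tile_window(k_dram_window, {0, kK0});         // kick off the next K prefetch

async_load_fence_raw(k_dram_window.get_num_access());  // wait until only the new
                                                        // prefetch remains outstanding
__builtin_amdgcn_s_barrier();                           // make the written LDS tile visible
__builtin_amdgcn_sched_barrier(0);                      // keep the compiler from reordering past here

gemm_0(s_acc, q, k_lds_window); // placeholder operands: consume the LDS-resident tile
```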
......
@@ -12,21 +12,22 @@
namespace ck_tile {
template <bool GateUpPreShuffled_ = false,
bool DownPreShuffled_ = false,
index_t NumPrefetchA_ = 2,
index_t NumPrefetchG_ = 2,
index_t NumPrefetchU_ = 2,
index_t NumPrefetchD_ = 2,
enum class FusedMoePermuteStyle
{
// permute_b_n0_k0_n1_k1_n2_k2 = 0, // 0,1,4,2,5,3,6
// permute_b_n0_n1_k0_k1_n2_k2 = 1, // 0,1,2,4,5,3,6
permute_b_nr_kr_kw_nw_kv = 2, // 0,1,3,4,2,5
permute_b_nr_kr_waveflatten = permute_b_nr_kr_kw_nw_kv,
no_permute = 999,
};
template <bool DownPreShuffled_ = false,
FusedMoePermuteStyle PermuteStyle_ = FusedMoePermuteStyle::permute_b_nr_kr_kw_nw_kv,
index_t kBlockPerCu_ = -1 /* overwrite occupancy if not -1 */>
struct FusedMoeTraits
{
static constexpr bool GateUpPreShuffled = GateUpPreShuffled_;
static constexpr bool DownPreShuffled = DownPreShuffled_;
static constexpr index_t NumPrefetchA = NumPrefetchA_;
static constexpr index_t NumPrefetchG = NumPrefetchG_;
static constexpr index_t NumPrefetchU = NumPrefetchU_;
static constexpr index_t NumPrefetchD = NumPrefetchD_;
static constexpr FusedMoePermuteStyle PermuteStyle = PermuteStyle_;
static constexpr index_t kBlockPerCu = kBlockPerCu_;
};
} // namespace ck_tile
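With the per-operand prefetch counts folded out of the trait, `FusedMoeTraits` now only selects the down-weight pre-shuffle, the weight permute style, and the occupancy override. A usage sketch against the new parameter list (the alias name is arbitrary):

```cpp
// Instantiate the slimmed-down trait: no pre-shuffled down weights,
// waveflatten weight layout, occupancy left to the compiler (-1).
using MoeTraits = ck_tile::FusedMoeTraits<false,
                                          ck_tile::FusedMoePermuteStyle::permute_b_nr_kr_waveflatten,
                                          -1>;

static_assert(MoeTraits::PermuteStyle ==
                  ck_tile::FusedMoePermuteStyle::permute_b_nr_kr_kw_nw_kv,
              "waveflatten is an alias of permute_b_nr_kr_kw_nw_kv");
```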