Commit 199f7f71 authored by carlushuang's avatar carlushuang
Browse files

modify moe

parent 33ceea62
...@@ -282,8 +282,8 @@ struct BlockFmhaPipelineQRKSVSAsync ...@@ -282,8 +282,8 @@ struct BlockFmhaPipelineQRKSVSAsync
store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse)); store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse));
} }
buffer_load_fence(0); // rocm-6.1, if whole tile is masked out, need to fence(0) buffer_load_fence_raw(0); // rocm-6.1, if whole tile is masked out, need to fence(0)
// otherwise will have compute error(maybe compiler bug?) // otherwise will have compute error(maybe compiler bug?)
// Note: here occ are all cleared, return it // Note: here occ are all cleared, return it
return o_acc; return o_acc;
...@@ -334,7 +334,7 @@ struct BlockFmhaPipelineQRKSVSAsync ...@@ -334,7 +334,7 @@ struct BlockFmhaPipelineQRKSVSAsync
move_tile_window(k_dram_window, {0, kK0}); move_tile_window(k_dram_window, {0, kK0});
__builtin_amdgcn_sched_barrier(0); __builtin_amdgcn_sched_barrier(0);
buffer_load_fence(k_dram_window.get_num_access(), q.get_thread_buffer()); buffer_load_fence_raw(k_dram_window.get_num_access(), q.get_thread_buffer());
(void)q_element_func; // ??? rocm-6.x if use q element func will have scratch on hdim=64/32 (void)q_element_func; // ??? rocm-6.x if use q element func will have scratch on hdim=64/32
// auto q_tile = q; // tile_elementwise_in(q_element_func, q); // auto q_tile = q; // tile_elementwise_in(q_element_func, q);
...@@ -359,7 +359,7 @@ struct BlockFmhaPipelineQRKSVSAsync ...@@ -359,7 +359,7 @@ struct BlockFmhaPipelineQRKSVSAsync
if constexpr(i_k0 < k0_loops - 1) if constexpr(i_k0 < k0_loops - 1)
move_tile_window(k_dram_window, {0, kK0}); move_tile_window(k_dram_window, {0, kK0});
async_load_fence(k_dram_window.get_num_access()); async_load_fence_raw(k_dram_window.get_num_access());
__builtin_amdgcn_s_barrier(); __builtin_amdgcn_s_barrier();
__builtin_amdgcn_sched_barrier(0); __builtin_amdgcn_sched_barrier(0);
gemm_0(s_acc, gemm_0(s_acc,
...@@ -381,7 +381,7 @@ struct BlockFmhaPipelineQRKSVSAsync ...@@ -381,7 +381,7 @@ struct BlockFmhaPipelineQRKSVSAsync
if constexpr(k0_loops <= 2) if constexpr(k0_loops <= 2)
__builtin_amdgcn_sched_barrier(0); __builtin_amdgcn_sched_barrier(0);
async_load_fence(); async_load_fence_raw();
__builtin_amdgcn_s_barrier(); __builtin_amdgcn_s_barrier();
const auto bias_tile = load_tile(bias_dram_window); // load bias tile const auto bias_tile = load_tile(bias_dram_window); // load bias tile
......
...@@ -31,8 +31,12 @@ struct WarpGemmImpl ...@@ -31,8 +31,12 @@ struct WarpGemmImpl
using BWarpTensor = static_distributed_tensor<BDataType, BWarpDstr>; using BWarpTensor = static_distributed_tensor<BDataType, BWarpDstr>;
using CWarpTensor = static_distributed_tensor<CDataType, CWarpDstr>; using CWarpTensor = static_distributed_tensor<CDataType, CWarpDstr>;
CK_TILE_DEVICE void operator()(CWarpTensor& c, const AWarpTensor& a, const BWarpTensor& b) const template <typename CTensor, typename ATensor, typename BTensor>
CK_TILE_DEVICE void operator()(CTensor& c, const ATensor& a, const BTensor& b) const
{ {
static_assert(detail::is_similiar_distributed_tensor_v<CTensor, CWarpTensor> &&
detail::is_similiar_distributed_tensor_v<ATensor, AWarpTensor> &&
detail::is_similiar_distributed_tensor_v<BTensor, BWarpTensor>);
using AVec = ext_vector_t<ADataType, AWarpTensor::get_thread_buffer_size()>; using AVec = ext_vector_t<ADataType, AWarpTensor::get_thread_buffer_size()>;
using BVec = ext_vector_t<BDataType, BWarpTensor::get_thread_buffer_size()>; using BVec = ext_vector_t<BDataType, BWarpTensor::get_thread_buffer_size()>;
using CVec = ext_vector_t<CDataType, CWarpTensor::get_thread_buffer_size()>; using CVec = ext_vector_t<CDataType, CWarpTensor::get_thread_buffer_size()>;
...@@ -49,8 +53,11 @@ struct WarpGemmImpl ...@@ -49,8 +53,11 @@ struct WarpGemmImpl
c.get_thread_buffer().template set_as<CVec>(I0, c_vec); c.get_thread_buffer().template set_as<CVec>(I0, c_vec);
} }
CK_TILE_DEVICE auto operator()(const AWarpTensor& a, const BWarpTensor& b) const template <typename ATensor, typename BTensor>
CK_TILE_DEVICE auto operator()(const ATensor& a, const BTensor& b) const
{ {
static_assert(detail::is_similiar_distributed_tensor_v<ATensor, AWarpTensor> &&
detail::is_similiar_distributed_tensor_v<BTensor, BWarpTensor>);
CWarpTensor c; CWarpTensor c;
using AVec = ext_vector_t<ADataType, AWarpTensor::get_thread_buffer_size()>; using AVec = ext_vector_t<ADataType, AWarpTensor::get_thread_buffer_size()>;
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment