Commit 7ccdbe16 authored by carlushuang

update

parent e2a318bc
......@@ -20,7 +20,7 @@ float fused_moegemm(fused_moegemm_traits t, fused_moegemm_args a, const ck_tile:
t.prec_sw == "fp32" && t.prec_sq == "fp32" && t.prec_kw == "fp32" && t.block_m == 32 && t.gate_only == 1)
{
using t_ = fmoe_<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 1, 0>;
fused_moegemm_<t_>(s, a);
r = fused_moegemm_<t_>(s, a);
}
// clang-format on
return r;
......
......@@ -121,8 +121,6 @@ auto create_args(int argc, char* argv[])
template <typename I, typename W, typename O, typename ST, typename SW, typename SQ, typename KW>
bool run(const ck_tile::ArgParser& arg_parser)
{
std::cout << "xxxx " << __LINE__ << std::flush << std::endl;
ck_tile::index_t tokens = arg_parser.get_int("t");
ck_tile::index_t experts = arg_parser.get_int("e");
ck_tile::index_t topk = arg_parser.get_int("k");
......@@ -173,7 +171,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
std::cout << "[" << prec_str << "]"
<< " t:" << tokens << ", e:" << experts << ", k:" << topk << ", st:" << stride
<< ", hidden:" << hidden_size << ", interm:" << intermediate_size << ", tp:" << tp
<< ", go:" << gate_only << ", q:" << fused_quant << std::flush;
<< ", shared_interm:" << shared_intermediate_size_0 << "|"
<< shared_intermediate_size_1 << ", go:" << gate_only << ", q:" << fused_quant
<< std::flush;
using TypeConfig = FusedMoeGemmTypeConfig<I, W, O, ST, SW, SQ, KW>;
using ADataType = typename TypeConfig::ADataType;
......@@ -191,7 +191,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
// host verify
ck_tile::HostTensor<ADataType> a_host({tokens, hidden_size}, {stride, 1});
ck_tile::HostTensor<GDataType> g_host({experts, shared_intermediate_size_0, hidden_size});
ck_tile::HostTensor<DDataType> d_host({experts, shared_intermediate_size_1, hidden_size});
ck_tile::HostTensor<DDataType> d_host({experts, hidden_size, shared_intermediate_size_1});
ck_tile::HostTensor<ODataType> o_host({tokens, hidden_size}, {stride, 1});
ck_tile::HostTensor<AScaleDataType> sa_host({tokens});
ck_tile::HostTensor<GScaleDataType> sg_host({shared_intermediate_size_0});
......@@ -207,6 +207,17 @@ bool run(const ck_tile::ArgParser& arg_parser)
{(max_num_tokens_padded + block_m - 1) / block_m});
ck_tile::HostTensor<IndexDataType> num_sorted_tiles_host({1});
#if 1
#if 1
ck_tile::FillStepRange<ADataType>{-.5f, .5f, 0.01f}(a_host);
ck_tile::FillStepRange<GDataType>{-.5f, .5f, 0.01f}(g_host);
ck_tile::FillStepRange<DDataType, false>{.5f, -.5f, -0.01f}(d_host);
ck_tile::FillStepRange<AScaleDataType>{0.f, 1.f, 0.01f}(sa_host);
ck_tile::FillStepRange<GScaleDataType>{0.f, 1.f, 0.01f}(sg_host);
ck_tile::FillStepRange<DScaleDataType>{0.f, 1.f, 0.01f}(sd_host);
ck_tile::FillStepRange<YSmoothScaleDataType>{0.f, 1.f, 0.01f}(sy_host);
ck_tile::FillStepRange<TopkWeightDataType>{-.5f, .5f, 0.01f}(topk_weight_host);
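// note: these step-range fills are deterministic, so failures reproduce exactly
// from run to run (handy when diffing against the torch dumps loaded in the
// #else branch below)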
#else
ck_tile::FillUniformDistribution<ADataType>{-.5f, .5f}(a_host);
ck_tile::FillUniformDistribution<GDataType>{-.5f, .5f}(g_host);
ck_tile::FillUniformDistribution<DDataType>{-.5f, .5f}(d_host);
......@@ -215,6 +226,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile::FillUniformDistribution<DScaleDataType>{-.5f, .5f}(sd_host);
ck_tile::FillUniformDistribution<YSmoothScaleDataType>{-.5f, .5f}(sy_host);
ck_tile::FillUniformDistribution<TopkWeightDataType>{-.5f, .5f}(topk_weight_host);
#endif
// permute weight
ck_tile::HostTensor<GDataType> g_perm_host = shuffle_moe_weight(g_host, prec_w, 1);
......@@ -236,6 +248,77 @@ bool run(const ck_tile::ArgParser& arg_parser)
{
topid_unique_gen<IndexDataType>(topk_ids_host.mData, tokens, topk, experts, 11913);
}
#else
a_host.loadtxt("../../ater/input_torch.txt");
topk_ids_host.loadtxt("../../ater/topk_ids_torch.txt", "int");
// topk_ids_host.savetxt("topk_ids_2.txt");
topk_weight_host.loadtxt("../../ater/topk_weights_torch.txt", "float");
std::cout << "------- @@@ " << __LINE__ << std::flush << std::endl;
g_host.loadtxt("../../ater/w1_torch.txt", "float");
std::cout << "------- @@@ " << __LINE__ << std::flush << std::endl;
d_host.loadtxt("../../ater/w2_torch.txt", "float");
std::cout << "------- @@@ " << __LINE__ << std::flush << std::endl;
ck_tile::HostTensor<GDataType> g_perm_host = shuffle_moe_weight(g_host, prec_w, 1);
std::cout << "------- @@@ " << __LINE__ << std::flush << std::endl;
ck_tile::HostTensor<DDataType> d_perm_host = shuffle_moe_weight(d_host, prec_w, 1);
std::cout << "------- @@@ " << __LINE__ << std::flush << std::endl;
ck_tile::reference_moe_sorting<TopkWeightDataType, IndexDataType>(
topk_ids_host,
topk_weight_host,
sorted_token_ids_host,
sorted_weight_host,
sorted_expert_ids_host,
num_sorted_tiles_host.mData[0],
experts,
block_m);
std::cout << "------- @@@ " << __LINE__ << std::flush << std::endl;
std::cout << sorted_token_ids_host << std::endl;
std::cout << num_sorted_tiles_host << std::endl;
std::cout << sorted_expert_ids_host << std::endl;
ck_tile::reference_fused_moe<AccDataType, ck_tile::element_wise::Gelu>(
a_host,
g_host,
d_host,
sa_host,
sg_host,
sd_host,
sy_host,
o_host,
sorted_token_ids_host,
sorted_weight_host,
sorted_expert_ids_host,
num_sorted_tiles_host,
topk_ids_host,
block_m,
tokens,
experts,
hidden_size,
shared_intermediate_size_0,
topk,
gate_only);
std::cout << "------- >" << std::endl;
std::cout << o_host << std::endl;
(void)balance;
{
ck_tile::HostTensor<ODataType> o_host_torch({tokens, hidden_size}, {stride, 1});
o_host_torch.loadtxt("../../ater/ref2_torch.txt");
auto [rtol, atol] = get_elimit<ADataType>();
bool pass = ck_tile::check_err(
o_host, o_host_torch, std::string("OUT-Torch Error: Incorrect results!"), rtol, atol);
std::cout << ", valid:" << (pass ? "y" : "n") << std::flush;
}
return true;
#endif
ck_tile::reference_moe_sorting<TopkWeightDataType, IndexDataType>(
topk_ids_host,
......@@ -247,8 +330,12 @@ bool run(const ck_tile::ArgParser& arg_parser)
experts,
block_m);
// std::cout << sorted_token_ids_host << std::endl;
// std::cout << num_sorted_tiles_host << std::endl;
std::cout << sorted_token_ids_host << std::endl;
std::cout << num_sorted_tiles_host << std::endl;
std::cout << sorted_expert_ids_host << std::endl;
std::cout << topk_weight_host << std::endl;
std::cout << sorted_weight_host << std::endl;
// done, preparing GPU buffer
ck_tile::DeviceMem a_buf(a_host);
......
......@@ -102,4 +102,28 @@ CK_TILE_DEVICE T warp_shuffle(const T& v_local, uint32_t src_lane)
#endif
}
template <typename T>
CK_TILE_DEVICE auto flag_to_exec(const T& v_flag)
{
static_assert(sizeof(T) == 4);
// per-thread v_flag store into 2x sgpr
uint32x2_t exec_flag;
asm volatile("v_cmp_ge_u32 %[s_exec_flag], %[v_flag], 1"
: [s_exec_flag] "=s"(exec_flag)
: [v_flag] "v"(v_flag));
return exec_flag;
}
template <typename X, typename Y>
CK_TILE_DEVICE auto cmp_lt_to_exec(const X& x, const Y& y)
{
static_assert(sizeof(X) == 4 && sizeof(Y) == 4);
// per-thread cmp store into 2x sgpr
uint32x2_t exec_flag;
asm volatile("v_cmp_lt_u32 %[s_exec_flag], %[v_x], %[v_y]"
: [s_exec_flag] "=s"(exec_flag)
: [v_x] "v"(x), [v_y] "v"(y));
return exec_flag;
}
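// A hedged usage sketch (illustrative only; store_if_row_valid is not part of
// this header): swap the mask from cmp_lt_to_exec() into exec around a
// predicated store, then restore the saved mask, mirroring the
// s_mov_b64 exec / global_atomic_pk_add_bf16 sequences in the flatmm ukernels.
template <typename T>
CK_TILE_DEVICE void store_if_row_valid(uint32_t row_id, uint32_t num_rows, const T& v, T* p)
{
    static_assert(sizeof(T) == 4);
    uint32x2_t mask = cmp_lt_to_exec(row_id, num_rows); // lanes with row_id < num_rows stay active
    uint32x2_t saved_exec;
    asm volatile("s_mov_b64 %[s], exec" : [s] "=s"(saved_exec)); // save the full wave mask
    asm volatile("s_mov_b64 exec, %[m]" ::[m] "s"(mask));        // mask off out-of-range lanes
    p[row_id] = v;                                               // per-lane predicated store
    asm volatile("s_mov_b64 exec, %[s]" ::[s] "s"(saved_exec));  // restore exec
}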
} // namespace ck_tile
......@@ -235,6 +235,44 @@ struct FillMonotonicSeq
}
};
template <typename T, bool IsAscending = true>
struct FillStepRange
{
float start_value_{0};
float end_value_{3};
float step_{1};
template <typename ForwardIter>
void operator()(ForwardIter first, ForwardIter last) const
{
std::generate(first, last, [=, n = start_value_]() mutable {
auto tmp = n;
n += step_;
if constexpr(IsAscending)
{
if(n > end_value_)
n = start_value_;
}
else
{
if(n < end_value_)
n = start_value_;
}
return type_convert<T>(tmp);
});
}
template <typename ForwardRange>
auto operator()(ForwardRange&& range) const -> std::void_t<
decltype(std::declval<const FillStepRange&>()(std::begin(std::forward<ForwardRange>(range)),
std::end(std::forward<ForwardRange>(range))))>
{
(*this)(std::begin(std::forward<ForwardRange>(range)),
std::end(std::forward<ForwardRange>(range)));
}
};
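// A minimal usage sketch (shape and values illustrative): FillStepRange walks
// from start to end by step and wraps around; IsAscending=false counts down.
//   ck_tile::HostTensor<float> t({4, 8});
//   ck_tile::FillStepRange<float>{-.5f, .5f, 0.01f}(t);         // -0.50, -0.49, ..., 0.50, wrap
//   ck_tile::FillStepRange<float, false>{.5f, -.5f, -0.01f}(t); // descending variant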
template <typename T>
struct FillConstant
{
......
......@@ -12,6 +12,7 @@
#include <utility>
#include <vector>
#include <functional>
#include <fstream>
#include "ck_tile/core.hpp"
#include "ck_tile/host/ranges.hpp"
......@@ -589,7 +590,7 @@ struct HostTensor
return ck_tile::span<Element>{reinterpret_cast<Element*>(data()),
size() * FromSize / ToSize};
}
#if 1
friend std::ostream& operator<<(std::ostream& os, const HostTensor<T>& t)
{
os << t.mDesc;
......@@ -600,11 +601,90 @@ struct HostTensor
{
os << ", ";
}
if constexpr(std::is_same_v<T, bf16_t> || std::is_same_v<T, fp16_t>)
{
os << type_convert<float>(t.mData[idx]);
}
else
{
os << t.mData[idx];
}
}
os << "]";
return os;
}
#endif
// read data from a file, as dtype
// the file can be dumped from torch like this (t is the target tensor):
// numpy.savetxt("f.txt", t.view(-1).numpy())
// numpy.savetxt("f.txt", t.cpu().view(-1).numpy()) # move from cuda to cpu, then save
// numpy.savetxt("f.txt", t.cpu().view(-1).numpy(), fmt="%d") # save as int
// the resulting f.txt holds one value per line
// dtype is "float" or "int"; values are cast internally to the real type T
void loadtxt(std::string file_name, std::string dtype = "float")
{
std::ifstream file(file_name);
if(file.is_open())
{
std::string line;
index_t cnt = 0;
while(std::getline(file, line))
{
if(cnt >= static_cast<index_t>(mData.size()))
{
throw std::runtime_error(std::string("data read from file:") + file_name +
" is too big");
}
if(dtype == "float")
{
mData[cnt] = type_convert<T>(std::stof(line));
}
else if(dtype == "int" || dtype == "int32")
{
mData[cnt] = type_convert<T>(std::stoi(line));
}
cnt++;
}
file.close();
if(cnt < static_cast<index_t>(mData.size()))
{
std::cerr << "Warning! reading from file:" << file_name
<< ", does not match the size of this tensor" << std::endl;
}
}
else
{
// fail loudly if the file cannot be opened
throw std::runtime_error(std::string("unable to open file:") + file_name);
}
}
// can save to a txt file and read from torch as:
// torch.from_numpy(np.loadtxt('f.txt', dtype=np.int32/np.float32...)).view([...]).contiguous()
void savetxt(std::string file_name)
{
std::ofstream file(file_name);
if(file.is_open())
{
for(auto& itm : mData)
{
file << itm << std::endl;
}
file.close();
}
else
{
// fail loudly if the file cannot be opened
throw std::runtime_error(std::string("unable to open file:") + file_name);
}
}
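// A hedged round-trip sketch (file name illustrative):
//   ck_tile::HostTensor<float> t({2, 3});
//   ck_tile::FillStepRange<float>{0.f, 1.f, 0.1f}(t);
//   t.savetxt("t.txt"); // one value per line, torch-loadable via np.loadtxt
//   ck_tile::HostTensor<float> u({2, 3});
//   u.loadtxt("t.txt"); // dtype defaults to "float"; throws if the file is missing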
Descriptor mDesc;
Data mData;
......
......@@ -53,12 +53,12 @@ template <typename AccDataType, // you only need to explicitly set this one
typename IndexDataType>
void reference_fused_moe(
const ck_tile::HostTensor<ADataType>& a_host, // [tokens, hidden_size]
const ck_tile::HostTensor<GDataType>& g_host, // [experts, interme_size, hidden_size]
const ck_tile::HostTensor<DDataType>& d_host, // [experts, hidden_size, hidden_size]
const ck_tile::HostTensor<GDataType>& g_host, // [experts, interme_size_0, hidden_size]
const ck_tile::HostTensor<DDataType>& d_host, // [experts, hidden_size, interme_size_1]
const ck_tile::HostTensor<AScaleDataType>& sa_host, // [tokens, 1],
const ck_tile::HostTensor<GScaleDataType>& sg_host, // [experts, 1, interme_size]
const ck_tile::HostTensor<GScaleDataType>& sg_host, // [experts, 1, interme_size_0]
const ck_tile::HostTensor<DScaleDataType>& sd_host, // [experts, 1, hidden_size],
const ck_tile::HostTensor<YSmoothScaleDataType>& sy_host, // [experts, 1, interme_size]
const ck_tile::HostTensor<YSmoothScaleDataType>& sy_host, // [experts, 1, interme_size_0]
ck_tile::HostTensor<ODataType>& o_host, // [tokens, hidden_size]
const ck_tile::HostTensor<IndexDataType>& sorted_token_ids_host, // [max_num_tokens_padded]
const ck_tile::HostTensor<TopkWeightDataType>& sorted_weight_host, // [max_num_tokens_padded]
......@@ -73,7 +73,7 @@ void reference_fused_moe(
ck_tile::index_t tokens,
ck_tile::index_t experts,
ck_tile::index_t hidden_size,
ck_tile::index_t intermediate_size,
ck_tile::index_t intermediate_size, // this size is for gate/up
ck_tile::index_t topk,
ck_tile::index_t gate_only)
{
......@@ -81,7 +81,9 @@ void reference_fused_moe(
assert(sorted_weight_host.get_num_of_dimension() == 1);
assert(sorted_expert_ids_host.get_num_of_dimension() == 1);
assert(num_sorted_tiles_host.get_element_size() == 1);
ck_tile::index_t num_sorted_tiles = num_sorted_tiles_host.mData[0];
ck_tile::index_t num_sorted_tiles = num_sorted_tiles_host.mData[0] / block_m;
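// (reference_moe_sorting writes the padded token count into
// num_sorted_tiles_host; dividing by block_m converts it into a tile count)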
ck_tile::index_t intermediate_size_0 = intermediate_size;
ck_tile::index_t intermediate_size_1 = intermediate_size / (gate_only ? 1 : 2);
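// e.g. with intermediate_size = 512: gate_only=1 gives size_0 = size_1 = 512
// (gate projection only); gate_only=0 gives size_0 = 512 gate+up columns out of
// the first gemm and size_1 = 256 rows feeding the second (down) gemm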
// TODO: better to remove this in the future, or modify the token_id value
auto get_topk_id = [&](ck_tile::index_t token_id_, ck_tile::index_t expert_id_) {
......@@ -90,6 +92,7 @@ void reference_fused_moe(
if(token_ids_host(token_id_, i_) == expert_id_)
return i_;
}
throw std::runtime_error("not correct token/expert pair\n");
return -1; // TODO: not correct!!
};
......@@ -108,9 +111,9 @@ void reference_fused_moe(
ck_tile::index_t i_topk = get_topk_id(i_token, i_expert); // TODO: ugly
auto weight = sorted_weight_host.mData[i_flatten];
ck_tile::HostTensor<AccDataType> acc_0({1, intermediate_size});
ck_tile::HostTensor<AccDataType> acc_0({1, intermediate_size_0});
// first gemm
for(ck_tile::index_t i_n = 0; i_n < intermediate_size; i_n++)
for(ck_tile::index_t i_n = 0; i_n < intermediate_size_0; i_n++)
{
AccDataType acc = static_cast<AccDataType>(0);
for(ck_tile::index_t i_k = 0; i_k < hidden_size; i_k++)
......@@ -121,32 +124,38 @@ void reference_fused_moe(
acc_0(0, i_n) = acc;
}
ck_tile::HostTensor<AccDataType> y({1, hidden_size});
ck_tile::HostTensor<AccDataType> y({1, intermediate_size_1});
if(gate_only)
{
assert(hidden_size == intermediate_size);
for(ck_tile::index_t i_n = 0; i_n < hidden_size; i_n++)
if(intermediate_size_1 != intermediate_size_0)
throw std::runtime_error(
"intermediate_size not correct, 0:" + std::to_string(intermediate_size_0) +
", 1:" + std::to_string(intermediate_size_1));
for(ck_tile::index_t i_n = 0; i_n < intermediate_size_1; i_n++)
{
Activation{}(y(0, i_n), acc_0(0, i_n));
}
}
else
{
assert(hidden_size * 2 == intermediate_size);
for(ck_tile::index_t i_n = 0; i_n < hidden_size; i_n++)
if(intermediate_size_1 * 2 != intermediate_size_0)
throw std::runtime_error(
"intermediate_size not correct, 0:" + std::to_string(intermediate_size_0) +
", 1:" + std::to_string(intermediate_size_1));
for(ck_tile::index_t i_n = 0; i_n < intermediate_size_1; i_n++)
{
AccDataType tmp;
Activation{}(tmp, acc_0(0, i_n));
y(0, i_n) = tmp * acc_0(0, i_n + hidden_size); // TODO: elementwise mul
y(0, i_n) = tmp * acc_0(0, i_n + intermediate_size_1); // TODO: elementwise mul
}
}
// second gemm
// second gemm, loop along gemm-n
ck_tile::HostTensor<AccDataType> acc_1({1, hidden_size});
for(ck_tile::index_t i_n = 0; i_n < hidden_size; i_n++)
{
AccDataType acc = static_cast<AccDataType>(0);
for(ck_tile::index_t i_k = 0; i_k < hidden_size; i_k++)
for(ck_tile::index_t i_k = 0; i_k < intermediate_size_1; i_k++)
{
acc += y(0, i_k) * type_convert<AccDataType>(d_host(i_expert, i_n, i_k));
}
......@@ -165,12 +174,12 @@ void reference_fused_moe(
auto r = [&](auto i_token) {
for(ck_tile::index_t i_n = 0; i_n < hidden_size; i_n++)
{
ODataType acc = type_convert<ODataType>(0);
AccDataType acc = type_convert<AccDataType>(0);
for(ck_tile::index_t i_topk = 0; i_topk < topk; i_topk++)
{
acc += out_topk_tokens(i_token, i_topk, i_n);
}
o_host(i_token, i_n) = acc;
o_host(i_token, i_n) = type_convert<ODataType>(acc);
}
};
make_ParallelTensorFunctor(r, tokens)(std::thread::hardware_concurrency());
......
......@@ -41,6 +41,17 @@ struct FlatmmSnUK_GFX9_32x128x512_1x4x1_16x16x16_BF16
using BDataType = bf16_t;
using ODataType = bf16_t;
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
{
// y y p p p y
// reg before shfl M0(2)*N0(2)*Nl(4)*Nw(4)*Mw(16)*Nv(4)
// but order is N0*M0*Nv
// in LDS we need store as
// M0(2)* N0(2) * Nl(4) * Nw(4) * (Mw(16)*Nv(4) + 4)
// y y wave-id lid/16 lid%16 v
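// i.e. 2*2*4*4*(16*4 + 4) = 4352 bf16 elements = 8704 bytes of LDS; the +4
// presumably pads each 64-element row to sidestep LDS bank conflicts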
return 2 * 2 * 4 * 4 * (16 * 4 + 4) * sizeof(bf16_t);
}
// TODO: needs to be paired with tile_window_linear!
// TODO: init_raw() must be called before calling this function!
// template <typename AWindow, typename BWindow, typename OWindow, typename ScaleTensor>
......@@ -48,30 +59,26 @@ struct FlatmmSnUK_GFX9_32x128x512_1x4x1_16x16x16_BF16
typename BCoords,
typename ORes,
typename OCoords,
typename OFlags,
typename ScaleTensor>
CK_TILE_DEVICE auto
operator()(const BRes& res_b,
const BCoords& cached_coords_b,
const ORes& res_o,
const OCoords& cached_coords_o,
const OFlags& o_flags, // this should be in sgpr
CK_TILE_LDS_ADDR void* smem,
// OWindow& o_window_,
index_t n, // loop along n dim
const ScaleTensor& scale_,
index_t stride_b, // stride b is fixed to blockKr * blockW, but still can adjust
index_t stride_o)
index_t tile_offset_b, // the tile offset of b is fixed to blockKr * blockW, but can still be adjusted
index_t tile_offset_o)
{
// auto cached_coords_b = b_window_.cached_coords_;
// auto res_b =
// b_window_.get_bottom_tensor_view().get_buffer_view().cached_buf_res_; auto
// cached_coords_o = o_window_.cached_coords_; auto res_o =
// o_window_.get_bottom_tensor_view().get_buffer_view().cached_buf_res_;
static_assert(BCoords::size() == 8); // 8
static_assert(OCoords::size() == 8);
const index_t stride_b_bytes = stride_b * sizeof(BDataType);
const index_t stride_o_bytes = stride_o * sizeof(ODataType);
const index_t tile_stride_b_bytes = tile_offset_b * sizeof(BDataType);
const index_t tile_stride_o_bytes = tile_offset_o * sizeof(ODataType);
static_assert(ScaleTensor::size() == 2);
float s0 = scale_[number<0>{}];
......@@ -143,6 +150,7 @@ struct FlatmmSnUK_GFX9_32x128x512_1x4x1_16x16x16_BF16
asm volatile(
";-------------------------------------------------------------\n"
" s_mov_b32 s52, 0x07060302 ; v_perm\n"
" s_mov_b64 s[38:39], exec ; save current exec\n"
" s_mov_b32 s8, %[s_res_o0] \n"
" s_mov_b32 s9, %[s_res_o1] \n"
" s_mov_b32 s12, %[s_res_b0] \n"
......@@ -247,10 +255,9 @@ struct FlatmmSnUK_GFX9_32x128x512_1x4x1_16x16x16_BF16
" buffer_load_dwordx4 acc[120:123], %[v_os_b7], s[12:15], 0 offen offset:2048 \n"
" buffer_load_dwordx4 acc[124:127], %[v_os_b7], s[12:15], 0 offen offset:3072 \n"
" s_cmp_gt_i32 %[s_loop_cnt] 1 ; move b with cond \n"
" s_cselect_b32 s86, %[s_stride_b], 0 \n"
" s_cselect_b32 s86, %[s_tile_os_b], 0 \n"
" s_add_u32 s12, s86, s12 \n"
" s_addc_u32 s13, 0, s13 \n"
" s_waitcnt vmcnt(24) \n"
"L_start%=: \n"
" s_waitcnt vmcnt(32) \n"
" s_barrier \n"
......@@ -517,39 +524,37 @@ struct FlatmmSnUK_GFX9_32x128x512_1x4x1_16x16x16_BF16
" ds_read_b32 %[c6], %[v_sfl_sld] offset:4416 + %[shfl_base] \n"
" ds_read_b32 %[c7], %[v_sfl_sld] offset:4448 + %[shfl_base] \n"
" s_waitcnt lgkmcnt(0) \n"
//" s_mov_b64 exec, s[16:17] \n"
// "s_endpgm\n"
" s_mov_b64 exec, %[s_execflag_0] \n"
" global_atomic_pk_add_bf16 %[v_os_o0], %[c0], s[8:9] \n"
//" s_mov_b64 exec, s[36:37] \n"
//" s_mov_b64 exec, s[18:19] \n"
" s_mov_b64 exec, %[s_execflag_1] \n"
" global_atomic_pk_add_bf16 %[v_os_o1], %[c1], s[8:9] \n"
//" s_mov_b64 exec, s[36:37] \n"
//" s_mov_b64 exec, s[20:21] \n"
" s_mov_b64 exec, %[s_execflag_2] \n"
" global_atomic_pk_add_bf16 %[v_os_o2], %[c2], s[8:9] \n"
//" s_mov_b64 exec, s[36:37] \n"
//" s_mov_b64 exec, s[22:23] \n"
" s_mov_b64 exec, %[s_execflag_3] \n"
" global_atomic_pk_add_bf16 %[v_os_o3], %[c3], s[8:9] \n"
//" s_mov_b64 exec, s[36:37] \n"
//" s_mov_b64 exec, s[24:25] \n"
" s_mov_b64 exec, %[s_execflag_4] \n"
" global_atomic_pk_add_bf16 %[v_os_o4], %[c4], s[8:9] \n"
//" s_mov_b64 exec, s[36:37] \n"
//" s_mov_b64 exec, s[26:27] \n"
" s_mov_b64 exec, %[s_execflag_5] \n"
" global_atomic_pk_add_bf16 %[v_os_o5], %[c5], s[8:9] \n"
//" s_mov_b64 exec, s[36:37] \n"
//" s_mov_b64 exec, s[28:29] \n"
// "s_endpgm\n"
" s_mov_b64 exec, %[s_execflag_6] \n"
" global_atomic_pk_add_bf16 %[v_os_o6], %[c6], s[8:9] \n"
//" s_mov_b64 exec, s[36:37] \n"
//" s_mov_b64 exec, s[30:31] \n"
" s_mov_b64 exec, %[s_execflag_7] \n"
" global_atomic_pk_add_bf16 %[v_os_o7], %[c7], s[8:9] \n"
//" s_mov_b64 exec, s[36:37] \n"
" s_mov_b64 exec, s[38:39] \n"
" s_sub_i32 %[s_loop_cnt], %[s_loop_cnt], 1 ; k-- \n"
" s_cmp_gt_i32 %[s_loop_cnt] 0 \n"
" s_cbranch_scc0 L_end%= \n"
" s_cmp_gt_i32 %[s_loop_cnt] 1 ; move b with cond \n"
" s_cselect_b32 s86, %[s_stride_b], 0 \n"
" s_cselect_b32 s86, %[s_tile_os_b], 0 \n"
" s_add_u32 s12, s86, s12 \n"
" s_addc_u32 s13, 0, s13 \n"
" s_add_u32 s8, %[s_stride_o], s8 \n"
" s_add_u32 s8, %[s_tile_os_o], s8 \n"
" s_addc_u32 s9, 0, s9 \n"
//" s_addk_i32 s80, 0x0080 \n"
//" s_cmp_lt_i32 s80, s81 \n"
//" s_cbranch_scc0 label_0E98 \n"
......@@ -817,38 +822,31 @@ struct FlatmmSnUK_GFX9_32x128x512_1x4x1_16x16x16_BF16
" ds_read_b32 %[c22], %[v_sfl_sld] offset:4416 + %[shfl_base] \n"
" ds_read_b32 %[c23], %[v_sfl_sld] offset:4448 + %[shfl_base] \n"
" s_waitcnt lgkmcnt(0) \n"
//" s_mov_b64 exec, s[16:17] \n"
" global_atomic_pk_add_bf16 %[v_os_o0], %[c16], s[8:9] \n"
//" s_mov_b64 exec, s[36:37] \n"
//" s_mov_b64 exec, s[18:19] \n"
" global_atomic_pk_add_bf16 %[v_os_o1], %[c17], s[8:9] \n"
//" s_mov_b64 exec, s[36:37] \n"
//" s_mov_b64 exec, s[20:21] \n"
" global_atomic_pk_add_bf16 %[v_os_o2], %[c18], s[8:9] \n"
//" s_mov_b64 exec, s[36:37] \n"
//" s_mov_b64 exec, s[22:23] \n"
" global_atomic_pk_add_bf16 %[v_os_o3], %[c19], s[8:9] \n"
//" s_mov_b64 exec, s[36:37] \n"
//" s_mov_b64 exec, s[24:25] \n"
" global_atomic_pk_add_bf16 %[v_os_o4], %[c20], s[8:9] \n"
//" s_mov_b64 exec, s[36:37] \n"
//" s_mov_b64 exec, s[26:27] \n"
" global_atomic_pk_add_bf16 %[v_os_o5], %[c21], s[8:9] \n"
//" s_mov_b64 exec, s[36:37] \n"
//" s_mov_b64 exec, s[28:29] \n"
" global_atomic_pk_add_bf16 %[v_os_o6], %[c22], s[8:9] \n"
//" s_mov_b64 exec, s[36:37] \n"
//" s_mov_b64 exec, s[30:31] \n"
" global_atomic_pk_add_bf16 %[v_os_o7], %[c23], s[8:9] \n"
//" s_mov_b64 exec, s[36:37] \n"
" s_mov_b64 exec, %[s_execflag_0] \n"
" global_atomic_pk_add_bf16 %[v_os_o0], %[c0], s[8:9] \n"
" s_mov_b64 exec, %[s_execflag_1] \n"
" global_atomic_pk_add_bf16 %[v_os_o1], %[c1], s[8:9] \n"
" s_mov_b64 exec, %[s_execflag_2] \n"
" global_atomic_pk_add_bf16 %[v_os_o2], %[c2], s[8:9] \n"
" s_mov_b64 exec, %[s_execflag_3] \n"
" global_atomic_pk_add_bf16 %[v_os_o3], %[c3], s[8:9] \n"
" s_mov_b64 exec, %[s_execflag_4] \n"
" global_atomic_pk_add_bf16 %[v_os_o4], %[c4], s[8:9] \n"
" s_mov_b64 exec, %[s_execflag_5] \n"
" global_atomic_pk_add_bf16 %[v_os_o5], %[c5], s[8:9] \n"
" s_mov_b64 exec, %[s_execflag_6] \n"
" global_atomic_pk_add_bf16 %[v_os_o6], %[c6], s[8:9] \n"
" s_mov_b64 exec, %[s_execflag_7] \n"
" global_atomic_pk_add_bf16 %[v_os_o7], %[c7], s[8:9] \n"
" s_mov_b64 exec, s[38:39] \n"
" s_sub_i32 %[s_loop_cnt], %[s_loop_cnt], 1 ; k-- \n"
" s_cmp_gt_i32 %[s_loop_cnt] 0 \n"
" s_cbranch_scc0 L_end%= \n"
" s_cmp_gt_i32 %[s_loop_cnt] 1 ; move b with cond \n"
" s_cselect_b32 s86, %[s_stride_b], 0 \n"
" s_cselect_b32 s86, %[s_tile_os_b], 0 \n"
" s_add_u32 s12, s86, s12 \n"
" s_addc_u32 s13, 0, s13 \n"
" s_add_u32 s8, %[s_stride_o], s8 \n"
" s_add_u32 s8, %[s_tile_os_o], s8 \n"
" s_addc_u32 s9, 0, s9 \n"
" s_branch L_start%= \n"
"L_end%=: \n"
......@@ -917,13 +915,22 @@ struct FlatmmSnUK_GFX9_32x128x512_1x4x1_16x16x16_BF16
[v_os_b6]"v"(static_cast<index_t>(cached_coords_b[number<6>{}] * sizeof(BDataType))),
[v_os_b7]"v"(static_cast<index_t>(cached_coords_b[number<7>{}] * sizeof(BDataType))),
[s_stride_o]"s"(stride_o_bytes),
[s_stride_b]"s"(stride_b_bytes),
[s_tile_os_o]"s"(tile_stride_o_bytes),
[s_tile_os_b]"s"(tile_stride_b_bytes),
[scale_0]"v"(s0),
[scale_1]"v"(s1),
[v_nan_lo]"v"(nan_lo),
[v_nan_hi]"v"(nan_hi)
: "memory", "a0", "a1", "a2", "a3", "a4", "a5", "a6", "a7", "a8", "a9",
[v_nan_hi]"v"(nan_hi),
[s_execflag_0]"s"(o_flags[number<0>{}]),
[s_execflag_1]"s"(o_flags[number<1>{}]),
[s_execflag_2]"s"(o_flags[number<2>{}]),
[s_execflag_3]"s"(o_flags[number<3>{}]),
[s_execflag_4]"s"(o_flags[number<4>{}]),
[s_execflag_5]"s"(o_flags[number<5>{}]),
[s_execflag_6]"s"(o_flags[number<6>{}]),
[s_execflag_7]"s"(o_flags[number<7>{}])
:
"memory", "a0", "a1", "a2", "a3", "a4", "a5", "a6", "a7", "a8", "a9",
"a10", "a11", "a12", "a13", "a14", "a15", "a16", "a17", "a18", "a19",
"a20", "a21", "a22", "a23", "a24", "a25", "a26", "a27", "a28", "a29",
"a30", "a31", "a32", "a33", "a34", "a35", "a36", "a37", "a38", "a39",
......@@ -953,9 +960,13 @@ struct FlatmmSnUK_GFX9_32x128x512_1x4x1_16x16x16_BF16
"a236", "a237", "a238", "a239", "a240", "a241", "a242", "a243",
"a244", "a245", "a246", "a247", "a248", "a249", "a250", "a251",
"a252", "a253", "a254", "a255",
"s16", "s17", "s18", "s19", "s20", "s21", "s22", "s23",
"s8", "s9", "s12", "s13", "s14", "s15", "s38", "s39", "s52", "s86",
// "s32", "s33",
"v50", "v54", "v55",
"v64","v65","v66","v67","v68","v69","v70","v71",
"v72","v73","v74","v75","v76","v77","v78","v79",
"v80","v81","v82","v83","v84","v85","v86","v87",
"v88","v89","v90","v91","v92","v93","v94","v95",
"v128", "v129", "v130", "v131",
"v132", "v133", "v134", "v135", "v136", "v137", "v138", "v139",
"v140", "v141", "v142", "v143", "v144", "v145", "v146", "v147",
......
......@@ -241,17 +241,13 @@ struct FlatmmUK_GFX9_32x512x128_1x4x1_16x16x16_BF16
return enc_{};
}
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
{
return 32 * (128 + 8) * sizeof(bf16_t);
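// i.e. 32 * 136 = 4352 bf16 elements = 8704 bytes; the +8 presumably pads each
// 128-element row to avoid LDS bank conflicts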
}
// TODO: needs to be paired with tile_window_linear!
// TODO: init_raw() must be called before calling this function!
#if 0
template <typename AWindow, typename BWindow, typename SmemWindow>
CK_TILE_DEVICE auto operator()(const AWindow& a_window_,
const BWindow& b_window_,
SmemWindow& smem_window_,
index_t k,
index_t stride_a,
index_t stride_b) // stride b is fixed to blockKr * blockW, but still can adjust
#else
template <typename ARes, typename ACoords, typename BRes, typename BCoords>
CK_TILE_DEVICE auto
operator()(const ARes& res_a,
......@@ -260,9 +256,8 @@ struct FlatmmUK_GFX9_32x512x128_1x4x1_16x16x16_BF16
const BCoords& cached_coords_b,
CK_TILE_LDS_ADDR void* smem,
index_t k,
index_t stride_a,
index_t stride_b) // stride b is fixed to blockKr * blockW, but still can adjust
#endif
index_t tile_offset_a, // per-tile offset to advance on each unroll
index_t tile_offset_b) // per-tile offset to advance on each unroll
{
static_assert(ACoords::size() == Block_M * Block_K / BlockSize / 2 /*2x per dword*/); // 8
static_assert(BCoords::size() == Repeat_N);
......@@ -292,8 +287,8 @@ struct FlatmmUK_GFX9_32x512x128_1x4x1_16x16x16_BF16
make_static_tile_distribution(a_block_dstr_encode));
}();
const index_t stride_a_bytes = stride_a * sizeof(bf16_t);
const index_t stride_b_bytes = stride_b * sizeof(bf16_t);
const index_t tile_offset_a_bytes = tile_offset_a * sizeof(bf16_t);
const index_t tile_offset_b_bytes = tile_offset_b * sizeof(bf16_t);
const auto [m0_init_value, size_per_issue] = get_async_store_smem_info(a_sst);
constexpr auto smem_buf_size =
......@@ -343,9 +338,9 @@ struct FlatmmUK_GFX9_32x512x128_1x4x1_16x16x16_BF16
"buffer_load_dword %[v_os_a7], s[16:19], 0 offen lds \n"
"s_add_u32 m0, %[smem_sz], %[s_m0_init] \n"
"s_cmp_gt_i32 %[s_loop_cnt] 1 ; move a with cond \n"
"s_cselect_b32 s86, %[s_stride_a], 0 \n"
"s_add_u32 s16, s86, s16 \n"
"s_addc_u32 s17, 0, s17 \n"
"s_cselect_b32 s86, %[s_tile_os_a], 0 ; move a with cond \n"
"s_add_u32 s16, s86, s16 ; move a with cond \n"
"s_addc_u32 s17, 0, s17 ; move a with cond \n"
"; -- prefetch A1\n"
"buffer_load_dword %[v_os_a0], s[16:19], 0 offen lds \n"
"s_add_u32 m0, %[s_size_per_issue], m0 \n"
......@@ -364,9 +359,9 @@ struct FlatmmUK_GFX9_32x512x128_1x4x1_16x16x16_BF16
"buffer_load_dword %[v_os_a7], s[16:19], 0 offen lds \n"
"s_add_u32 m0, 0, %[s_m0_init] \n"
"s_cmp_gt_i32 %[s_loop_cnt] 2 ; move a with cond \n"
"s_cselect_b32 s86, %[s_stride_a], 0 \n"
"s_add_u32 s16, s86, s16 \n"
"s_addc_u32 s17, 0, s17 \n"
"s_cselect_b32 s86, %[s_tile_os_a], 0 ; move a with cond \n"
"s_add_u32 s16, s86, s16 ; move a with cond \n"
"s_addc_u32 s17, 0, s17 ; move a with cond \n"
"; -- prefetch B0\n"
"buffer_load_dwordx4 acc[0:3], %[v_os_b0], s[20:23], 0 offen \n"
"buffer_load_dwordx4 acc[4:7], %[v_os_b0], s[20:23], 0 offen offset:1024 \n"
......@@ -401,19 +396,19 @@ struct FlatmmUK_GFX9_32x512x128_1x4x1_16x16x16_BF16
"buffer_load_dwordx4 acc[120:123], %[v_os_b7], s[20:23], 0 offen offset:2048 \n"
"buffer_load_dwordx4 acc[124:127], %[v_os_b7], s[20:23], 0 offen offset:3072 \n"
"s_cmp_gt_i32 %[s_loop_cnt] 1 ; move b with cond \n"
"s_cselect_b32 s86, %[s_stride_b], 0 \n"
"s_add_u32 s20, s86, s20 \n"
"s_addc_u32 s21, 0, s21 \n"
"s_waitcnt vmcnt(40)\n"
"s_cselect_b32 s86, %[s_tile_os_b], 0 ; move b with cond \n"
"s_add_u32 s20, s86, s20 ; move b with cond \n"
"s_addc_u32 s21, 0, s21 ; move b with cond \n"
"s_waitcnt vmcnt(40) \n"
"s_barrier \n"
"ds_read_b128 v[64:67], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_0] \n" // 1024: N stride, 64 K stride
"ds_read_b128 v[68:71], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_1] \n"
"ds_read_b128 v[72:75], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_2] \n"
"ds_read_b128 v[76:79], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_3] \n"
"ds_read_b128 v[80:83], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_4] \n"
"ds_read_b128 v[84:87], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_5] \n"
"ds_read_b128 v[88:91], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_6] \n"
"ds_read_b128 v[92:95], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_7] \n"
"ds_read_b128 v[64:67], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_0]\n" // 1024: N stride, 64 K stride
"ds_read_b128 v[68:71], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_1]\n"
"ds_read_b128 v[72:75], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_2]\n"
"ds_read_b128 v[76:79], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_3]\n"
"ds_read_b128 v[80:83], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_4]\n"
"ds_read_b128 v[84:87], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_5]\n"
"ds_read_b128 v[88:91], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_6]\n"
"ds_read_b128 v[92:95], %[v_os_slda] offset:0*%[smem_sz] + %[sld_os_7]\n"
"L_start%=: \n"
" s_waitcnt vmcnt(24) & lgkmcnt(0) \n"
" s_barrier \n"
......@@ -601,18 +596,18 @@ struct FlatmmUK_GFX9_32x512x128_1x4x1_16x16x16_BF16
" v_mfma_f32_16x16x16_bf16 %[v_acc_15], acc[118:119], v[86:87], %[v_acc_15] \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_15], acc[120:121], v[88:89], %[v_acc_15] \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_15], acc[122:123], v[90:91], %[v_acc_15] \n"
" buffer_load_dwordx4 acc[252:255], %[v_os_b7], s[20:23], 0 offen offset:3072 \n"
" buffer_load_dwordx4 acc[252:255], %[v_os_b7], s[20:23], 0 offen offset:3072\n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_15], acc[124:125], v[92:93], %[v_acc_15] \n"
" v_mfma_f32_16x16x16_bf16 %[v_acc_15], acc[126:127], v[94:95], %[v_acc_15] \n"
" s_sub_i32 %[s_loop_cnt], %[s_loop_cnt], 1 \n"
" s_cmp_gt_i32 %[s_loop_cnt] 0 \n"
" s_cbranch_scc0 L_end%= \n"
" s_cmp_gt_i32 %[s_loop_cnt] 2 ; move a with cond \n"
" s_cselect_b32 s86, %[s_stride_a], 0 \n"
" s_cselect_b32 s86, %[s_tile_os_a], 0 \n"
" s_add_u32 s16, s86, s16 \n"
" s_addc_u32 s17, 0, s17 \n"
" s_cmp_gt_i32 %[s_loop_cnt] 1 ; move b with cond \n"
" s_cselect_b32 s86, %[s_stride_b], 0 \n"
" s_cselect_b32 s86, %[s_tile_os_b], 0 \n"
" s_add_u32 s20, s86, s20 \n"
" s_addc_u32 s21, 0, s21 \n"
" ;------------------------------------------ \n"
......@@ -809,11 +804,11 @@ struct FlatmmUK_GFX9_32x512x128_1x4x1_16x16x16_BF16
" s_cmp_gt_i32 %[s_loop_cnt] 0 \n"
" s_cbranch_scc0 L_end%= \n"
" s_cmp_gt_i32 %[s_loop_cnt] 2 ; move a with cond \n"
" s_cselect_b32 s86, %[s_stride_a], 0 \n"
" s_cselect_b32 s86, %[s_tile_os_a], 0 \n"
" s_add_u32 s16, s86, s16 \n"
" s_addc_u32 s17, 0, s17 \n"
" s_cmp_gt_i32 %[s_loop_cnt] 1 ; move b with cond \n"
" s_cselect_b32 s86, %[s_stride_b], 0 \n"
" s_cselect_b32 s86, %[s_tile_os_b], 0 \n"
" s_add_u32 s20, s86, s20 \n"
" s_addc_u32 s21, 0, s21 \n"
" s_branch L_start%= \n"
......@@ -875,8 +870,8 @@ struct FlatmmUK_GFX9_32x512x128_1x4x1_16x16x16_BF16
[sld_os_5]"n"(sld_os[number<5>{}].value),
[sld_os_6]"n"(sld_os[number<6>{}].value),
[sld_os_7]"n"(sld_os[number<7>{}].value),
[s_stride_a]"s"(stride_a_bytes),
[s_stride_b]"s"(stride_b_bytes)
[s_tile_os_a]"s"(tile_offset_a_bytes),
[s_tile_os_b]"s"(tile_offset_b_bytes)
: "memory", "a0", "a1", "a2", "a3", "a4", "a5", "a6", "a7", "a8", "a9",
"a10", "a11", "a12", "a13", "a14", "a15", "a16", "a17", "a18", "a19",
"a20", "a21", "a22", "a23", "a24", "a25", "a26", "a27", "a28", "a29",
......
......@@ -153,9 +153,25 @@ struct FusedMoeGemmKernel
CK_TILE_HOST static std::string GetName()
{
// sync with generate.py
#define _SS_ std::string
#define _TS_ std::to_string
// clang-format off
return "";
using S_ = BlockShape;
auto prec_str = [&] () {
std::string base_str = _SS_(t2s<ADataType>::name);
if (!std::is_same_v<ADataType, GDataType>) {
base_str += _SS_("_") + _SS_(t2s<GDataType>::name);
}
return base_str;
}();
return _SS_("fused_moe_") + _SS_(prec_str) + "_" +
_TS_(S_::Block_M0) + "x" + _TS_(S_::Block_N0) + "x" + _TS_(S_::Block_K0) + "x" + _TS_(S_::Block_N1) + "_" +
_TS_(S_::WarpPerBlock_M0) + "x" + _TS_(S_::WarpPerBlock_N0) + "x" + _TS_(S_::WarpPerBlock_K0) + "_" +
_TS_(S_::Warp_M0) + "x" + _TS_(S_::Warp_N0) + "x" + _TS_(S_::Warp_K0) + "_" + _SS_(Pipeline::name);
#undef _SS_
#undef _TS_
// clang-format on
}
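// e.g. the bf16 instance dispatched above (S<32, 512, 128, 128>, S<1, 4, 1>,
// S<16, 16, 32>) would render roughly as
// "fused_moe_bf16_32x512x128x128_1x4x1_16x16x32_fused_moe_flatmm_uk",
// assuming the Block_*/Warp_* fields map 1:1 onto those trait sequences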
......@@ -199,16 +215,13 @@ struct FusedMoeGemmKernel
constexpr index_t block_m = BlockShape::Block_M0;
int max_num_tokens_padded =
hargs.topk * hargs.num_tokens + hargs.num_experts * block_m - hargs.topk;
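// e.g. tokens=128, topk=2, experts=8, block_m=32 gives
// 2*128 + 8*32 - 2 = 510 rows in the worst-case sorted/padded layout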
// printf("xxx max_num_tokens_padded:%d\n", max_num_tokens_padded);
return Partitioner::GridSize(max_num_tokens_padded, hargs.intermediate_size);
}
CK_TILE_HOST static constexpr auto BlockSize() { return dim3(BlockSize_); }
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
{
// return max(Pipeline::GetSmemSize(), Epilogue::GetSmemSize());
return Pipeline::GetSmemSize();
}
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { return Pipeline::GetSmemSize(); }
CK_TILE_DEVICE void operator()(Kargs kargs) const
{
......@@ -222,6 +235,11 @@ struct FusedMoeGemmKernel
const auto [sorted_tile_id, intermediate_tile_id] =
Partitioner{}(num_sorted_tiles, kargs.intermediate_size);
// if(threadIdx.x == 0)
// printf("bid:%d,%d, num_sorted_tiles:%d, sorted_tile_id:%d(%d),
// intermediate_tile_id:%d\n", static_cast<int>(blockIdx.x),
// static_cast<int>(blockIdx.y), num_sorted_tiles, sorted_tile_id, sorted_tile_id >=
// num_sorted_tiles? 1 : 0, intermediate_tile_id);
if(sorted_tile_id >= num_sorted_tiles)
return;
......
......@@ -66,17 +66,15 @@ struct FusedMoeGemmPipeline_FlatmmUk
}
}();
static constexpr const char* name = "fused_moe_flatmm_uk";
// TODO: there are multiple buffers
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize_A()
{
return Policy::template GetSmemSize_A<Problem>();
}
static constexpr const char* name = "flatmm_uk";
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
{
return Policy::template GetSmemSize<Problem>();
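// a single LDS allocation is reused by gemm-0, gemm-1 and the bridge y-tile
// between them, so the pipeline needs the max of the three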
constexpr index_t smem_0 = Policy::template GetUK_0<Problem>().GetSmemSize();
constexpr index_t smem_1 = Policy::template GetUK_1<Problem>().GetSmemSize();
constexpr index_t smem_bridge =
BlockShape::Block_M0 * BlockShape::Block_N0 * sizeof(YDataType);
return max(smem_0, max(smem_1, smem_bridge));
}
// this is the thread-offset along row/col
......@@ -154,7 +152,7 @@ struct FusedMoeGemmPipeline_FlatmmUk
{
constexpr index_t n_size = coords.size();
array<index_t, n_size> w;
array<TopkWeightDataType, n_size> w;
static_for<0, n_size, 1>{}([&](auto i) {
w.at(i) = sorted_weight_ptr[coords[i]]; // base_coord + i * MLans;
});
......@@ -207,34 +205,49 @@ struct FusedMoeGemmPipeline_FlatmmUk
index_t sorted_tile_id,
index_t intermediate_tile_id)
{
index_t nr_0 = kargs.intermediate_size / BlockShape::Block_Nr0;
index_t kr_0 = kargs.hidden_size / BlockShape::Block_Kr0;
index_t nr_1 = kargs.hidden_size / BlockShape::Block_Nr1; // should be same as kr_0
index_t kr_1 = kargs.intermediate_size / BlockShape::Block_Kr1; // should be same as nr_0
constexpr index_t hidden_radio_0 = IsGateOnly ? 1 : 2;
ck_tile::index_t shared_intermediate_size_0 = kargs.intermediate_size;
// w1 (Down, N size)
ck_tile::index_t shared_intermediate_size_1 = kargs.intermediate_size / hidden_radio_0;
index_t nr_0 = shared_intermediate_size_0 / BlockShape::Warp_N0; // divide N in W
index_t kr_0 = kargs.hidden_size / BlockShape::Warp_K0; // divide K in W
index_t nr_1 = kargs.hidden_size / BlockShape::Warp_N1;
index_t kr_1 = shared_intermediate_size_1 / BlockShape::Warp_K1;
const IndexDataType expert_id = __builtin_amdgcn_readfirstlane(
reinterpret_cast<const IndexDataType*>(kargs.sorted_expert_ids_ptr)[sorted_tile_id]);
constexpr index_t hidden_radio_0 = IsGateOnly ? 1 : 2;
index_t expert_stride_0 = kargs.intermediate_size * hidden_radio_0 * kargs.hidden_size;
index_t expert_stride_1 = kargs.intermediate_size * kargs.hidden_size;
index_t expert_stride_0 = shared_intermediate_size_0 * kargs.hidden_size;
index_t expert_stride_1 = shared_intermediate_size_1 * kargs.hidden_size;
// nr*kr*w
index_t interm_idx_nr = __builtin_amdgcn_readfirstlane(
intermediate_tile_id *
BlockShape::Block_Nr0); // intermediate_tile_id * Block_N / (N in W)
index_t interm_idx_nr =
__builtin_amdgcn_readfirstlane(intermediate_tile_id * BlockShape::Block_Nr0);
// printf("bid:%d,%d, sorted_tile_id:%d(, intermediate_tile_id:%d, expert_id:%d,
// interm_idx_nr:%d\n", static_cast<int>(blockIdx.x),
// static_cast<int>(blockIdx.y), sorted_tile_id, intermediate_tile_id, expert_id,
// interm_idx_nr);
auto row_coords_a = GetRowCoords_A(sorted_tile_id * BlockShape::Block_M0);
auto row_ids_a = GetRowID_A(
row_coords_a, reinterpret_cast<const IndexDataType*>(kargs.sorted_token_ids_ptr));
auto a_coords = generate_tuple([&](auto i) { return row_ids_a[i] * kargs.stride_token; },
auto a_coords = generate_tuple(
[&](auto i) {
return row_ids_a[i] * kargs.stride_token +
threadIdx.x % (BlockShape::Block_K0 / kAlignmentA) * kAlignmentA;
},
number<row_ids_a.size()>{});
auto a_res =
make_wave_buffer_resource(reinterpret_cast<const ADataType*>(kargs.a_ptr),
kargs.num_tokens * kargs.stride_token * sizeof(ADataType));
const auto g_win = [&]() {
auto g_win = [&]() {
const GDataType* g_ptr = reinterpret_cast<const GDataType*>(kargs.g_ptr) +
static_cast<long_index_t>(expert_id) * expert_stride_0 +
interm_idx_nr * kr_0 * BlockShape::Block_W0;
const auto g_view_ = make_naive_tensor_view<address_space_enum::global>(
auto g_view_ = make_naive_tensor_view<address_space_enum::global>(
g_ptr,
make_tuple(nr_0, kr_0, number<BlockShape::Block_W0>{}),
make_tuple(kr_0 * BlockShape::Block_W0, number<BlockShape::Block_W0>{}, 1),
......@@ -243,8 +256,8 @@ struct FusedMoeGemmPipeline_FlatmmUk
// number<BlockShape::Block_Nr0>{}.fff();
// number<kAlignmentG>{}.zzz();
const auto g_window_ =
make_tile_window_linear(g_view_,
auto g_window_ = make_tile_window_linear_raw(
g_view_,
make_tuple(number<BlockShape::Block_Nr0>{},
number<BlockShape::Block_Kr0>{},
number<BlockShape::Block_W0>{}),
......@@ -271,8 +284,8 @@ struct FusedMoeGemmPipeline_FlatmmUk
number<kAlignmentD>{},
number<1>{});
const auto d_window_ =
make_tile_window_linear(d_view_,
const auto d_window_ = make_tile_window_linear_raw(
d_view_,
make_tuple(number<BlockShape::Block_Nr1>{},
number<BlockShape::Block_Kr1>{},
number<BlockShape::Block_W1>{}),
......@@ -309,14 +322,23 @@ struct FusedMoeGemmPipeline_FlatmmUk
constexpr auto i_nr_ = number<i % Nr_>{};
constexpr auto i_kr0_ = number<i / Nr_>{};
return i_nr_ * kargs.intermediate_size * Nw_ * Nl_ + i_kr0_ * Kr1_ * W_ +
return i_nr_ * shared_intermediate_size_1 * Nw_ * Nl_ + i_kr0_ * Kr1_ * W_ +
base_os_;
},
number<num_offsets_>{});
}();
#endif
auto o_coords = generate_tuple([&](auto i) { return row_ids_a[i] * kargs.stride_token; },
number<a_coords.size()>{});
auto o_coords = generate_tuple(
[&](auto i) {
return row_ids_a[i] * kargs.stride_token +
threadIdx.x % (BlockShape::Block_N1 / kAlignmentO) * kAlignmentO;
},
number<row_ids_a.size()>{});
auto o_flags =
generate_tuple([&](auto i) { return cmp_lt_to_exec(row_ids_a[i], kargs.num_tokens); },
number<row_ids_a.size()>{});
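// each o_flag is a wave-wide exec mask built by cmp_lt_to_exec(): lanes whose
// sorted row maps to a padding token (row id >= num_tokens) are switched off
// before the global_atomic_pk_add_bf16 stores in the second ukernel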
auto bridge_sst_win = [&]() {
return make_tile_window(
make_tensor_view<address_space_enum::lds>(
......@@ -332,7 +354,79 @@ struct FusedMoeGemmPipeline_FlatmmUk
auto row_coords_o = GetRowCoords_O(sorted_tile_id * BlockShape::Block_M0);
auto w_scale = GetWeightScale(
row_coords_o, reinterpret_cast<const TopkWeightDataType*>(kargs.sorted_weight_ptr));
#if 0
printf("bid:%d,%d, tid:%d, sorted_tile_id:%d(, intermediate_tile_id:%d, e:%d, "
"interm_idx_nr:%d, coords:a:%d,%d,%d, row_ids_a:%d,%d,%d, (%d)g_coords:%d.%d.%d, "
"o_coords:%d,%d,%d,%d,%d,%d,%d,%d(%d,%d,%d,%d,%d,%d,%d,%d)\n",
static_cast<int>(blockIdx.x),
static_cast<int>(blockIdx.y),
static_cast<int>(threadIdx.x),
sorted_tile_id,
intermediate_tile_id,
expert_id,
interm_idx_nr,
row_coords_a[0],
row_coords_a[1],
row_coords_a[7],
row_ids_a[0],
row_ids_a[1],
row_ids_a[7],
kr_0 * BlockShape::Block_W0,
g_coords[number<0>{}],
g_coords[number<1>{}],
g_coords[number<7>{}],
o_coords[number<0>{}],
o_coords[number<1>{}],
o_coords[number<2>{}],
o_coords[number<3>{}],
o_coords[number<4>{}],
o_coords[number<5>{}],
o_coords[number<6>{}],
o_coords[number<7>{}],
// (row_ids_a[0] >= kargs.num_tokens ? 1 : 0),
// (row_ids_a[1] >= kargs.num_tokens ? 1 : 0),
// (row_ids_a[2] >= kargs.num_tokens ? 1 : 0),
// (row_ids_a[3] >= kargs.num_tokens ? 1 : 0),
// (row_ids_a[4] >= kargs.num_tokens ? 1 : 0),
// (row_ids_a[5] >= kargs.num_tokens ? 1 : 0),
// (row_ids_a[6] >= kargs.num_tokens ? 1 : 0),
// (row_ids_a[7] >= kargs.num_tokens ? 1 : 0)
(row_ids_a[0] < kargs.num_tokens && static_cast<index_t>(o_coords[number<0>{}]) >=
(kargs.num_tokens * kargs.stride_token)
? 7777
: 0),
(row_ids_a[1] < kargs.num_tokens && static_cast<index_t>(o_coords[number<1>{}]) >=
(kargs.num_tokens * kargs.stride_token)
? 7777
: 0),
(row_ids_a[2] < kargs.num_tokens && static_cast<index_t>(o_coords[number<2>{}]) >=
(kargs.num_tokens * kargs.stride_token)
? 7777
: 0),
(row_ids_a[3] < kargs.num_tokens && static_cast<index_t>(o_coords[number<3>{}]) >=
(kargs.num_tokens * kargs.stride_token)
? 7777
: 0),
(row_ids_a[4] < kargs.num_tokens && static_cast<index_t>(o_coords[number<4>{}]) >=
(kargs.num_tokens * kargs.stride_token)
? 7777
: 0),
(row_ids_a[5] < kargs.num_tokens && static_cast<index_t>(o_coords[number<5>{}]) >=
(kargs.num_tokens * kargs.stride_token)
? 7777
: 0),
(row_ids_a[6] < kargs.num_tokens && static_cast<index_t>(o_coords[number<6>{}]) >=
(kargs.num_tokens * kargs.stride_token)
? 7777
: 0),
(row_ids_a[7] < kargs.num_tokens && static_cast<index_t>(o_coords[number<7>{}]) >=
(kargs.num_tokens * kargs.stride_token)
? 7777
: 0)
);
#endif
auto uk_0 = Policy::template GetUK_0<Problem>();
auto acc_0 = uk_0(a_res,
a_coords,
......@@ -340,25 +434,29 @@ struct FusedMoeGemmPipeline_FlatmmUk
g_coords,
smem,
kargs.hidden_size,
kargs.stride_token,
BlockShape::Block_Kr0 * BlockShape::Block_W0);
BlockShape::Block_K0, // tile offset for the A matrix each unroll
BlockShape::Block_Kr0 *
BlockShape::Block_W0); // tile offset for the B matrix each unroll
// return ;
sweep_tile(acc_0,
[&](auto idx) { typename Problem::GateActivation{}(acc_0(idx), acc_0[idx]); });
auto y_pre = cast_tile<YDataType>(acc_0);
store_tile(bridge_sst_win, y_pre);
block_sync_lds();
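// the activated gemm-0 tile is staged through LDS (the bridge window) so that
// uk_1 can reload y in the layout its flatmm B-side expects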
auto uk_1 = Policy::template GetUK_1<Problem>();
uk_1(d_res,
d_coords,
o_res,
o_coords,
o_flags,
smem,
kargs.hidden_size,
kargs.hidden_size, // total n number
w_scale,
BlockShape::Block_Kr0 * BlockShape::Block_W0,
kargs.stride_token);
BlockShape::Block_Nr1 * kr_1 * BlockShape::Block_W1, // along N
BlockShape::Block_N1); // along N
}
};
......