Commit db53dba4 authored by coderfeli

hotfix: gemm1 uses real tokens and gemm2 is OK

parent 58db931e
......@@ -35,7 +35,7 @@ using Col = ck::tensor_layout::gemm::ColumnMajor;
using A0DataType = F8;
using B0DataType = F8;
-using EDataType = F32;
+using EDataType = F16;
using AccDataType = F32;
using CShuffleDataType = F32;
using D0DataType = F32;
......@@ -133,6 +133,8 @@ using BElementOp = PassThrough;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr ck::index_t MPerBlock = 32;
+static constexpr ck::index_t BLOCKSIZE = 256;
+static constexpr ck::index_t NPerBlock = 128;
static constexpr ck::index_t MNPerXDL = 32;
static constexpr ck::index_t CShuffleMXDLPerWave = MPerBlock / 32;
static constexpr ck::index_t KPerBlock = 256 / sizeof(A0DataType);
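Note that the K tile above is sized in bytes rather than elements: with one-byte FP8 inputs, `256 / sizeof(A0DataType)` yields a 256-element K tile, and it would drop to 128 for two-byte F16. A quick sketch of that arithmetic, assuming CK's `f8_t` and `half_t` element types:

```cpp
// Hypothetical sanity check of the KPerBlock derivation above: a block
// always stages 256 bytes of A per M row, so the element count halves
// when the element width doubles.
static_assert(256 / sizeof(ck::f8_t) == 256, "FP8: 256 K elements per block");
static_assert(256 / sizeof(ck::half_t) == 128, "F16: 128 K elements per block");
```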
......@@ -156,7 +158,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm
< Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType,
AElementOp, BElementOp, CDEElementOp, GemmSpec,
//threadnum, mblock, nblock, kblock
-256, MPerBlock, 128, KPerBlock,
+BLOCKSIZE, MPerBlock, NPerBlock, KPerBlock,
// ak1, bk1
AK1, BK1,
// mn_perxdl
......@@ -196,10 +198,10 @@ int main(int argc, char* argv[])
ck::index_t valid_tile_num = 8;
ck::index_t sorted_size = sorted_tile_num * MPerBlock;
ck::index_t valid_size = valid_tile_num * MPerBlock;
-ck::index_t batch = 64;
+ck::index_t tokens = 64;
ck::index_t topk = 2;
-ck::index_t tokens = batch * topk;
+// ck::index_t tokens = batch * topk;
if(argc == 1)
{
......@@ -225,7 +227,6 @@ int main(int argc, char* argv[])
ck::index_t StrideA = K;
ck::index_t StrideB = K;
-// ck::index_t StrideD = 0;
ck::index_t StrideE = N;
constexpr ck::index_t NumDTensor = DsDataType::Size();
constexpr auto StrideDs = std::array<ck::index_t, NumDTensor>{0, 0};
......@@ -241,14 +242,14 @@ int main(int argc, char* argv[])
for (int i = 0; i < sorted_tile_num; i++) {
expert_ids.mData[i] = i;
}
-int token_per_tile = tokens / valid_tile_num;
+int token_per_tile = tokens * topk / valid_tile_num;
int tokenid = 0;
// sorted_token_ids.mData[0] = 0;
for (int i = 0; i < sorted_size; i++) {
int tile_off = i % valid_size;
if(tile_off < token_per_tile)
{
-sorted_token_ids.mData[i] = (tokenid % batch) | ((tokenid / batch) << 24);
+sorted_token_ids.mData[i] = (tokenid % tokens) | ((tokenid / tokens) << 24);
tokenid++;
}
else
......@@ -258,13 +259,13 @@ int main(int argc, char* argv[])
}
expert_ids.savetxt("expert_ids.txt", "int");
sorted_token_ids.savetxt("sorted_token_ids.txt", "int");
-Tensor<A0DataType> a0_t_k(HostTensorDescriptor({batch, K}, {K, 1}));
+Tensor<A0DataType> a0_t_k(HostTensorDescriptor({tokens, K}, {K, 1}));
Tensor<B0DataType> b0_e_n_k(HostTensorDescriptor({experts, N, K}, {N*K, K, 1}));
Tensor<B0DataType> b0_preshuffled(HostTensorDescriptor({experts, N, K}, {N*K, K, 1}));
-Tensor<D0DataType> d0_t_n(HostTensorDescriptor({batch, N}, {StrideDs[0], 0}));
+Tensor<D0DataType> d0_t_n(HostTensorDescriptor({tokens, N}, {StrideDs[0], 0}));
Tensor<D1DataType> d1_e_n(HostTensorDescriptor({experts, N}, {1, StrideDs[1]}));
-Tensor<EDataType> e_t_n_host_result(HostTensorDescriptor({batch, topk, N}, {topk * N, N, 1}));
-Tensor<EDataType> e_t_n_device_result(HostTensorDescriptor({batch, topk, N}, {topk * N, N, 1}));
+Tensor<EDataType> e_t_n_host_result(HostTensorDescriptor({tokens, topk, N}, {topk * N, N, 1}));
+Tensor<EDataType> e_t_n_device_result(HostTensorDescriptor({tokens, topk, N}, {topk * N, N, 1}));
std::cout << "a0_t_k: " << a0_t_k.mDesc << std::endl;
std::cout << "b0_e_n_k: " << b0_e_n_k.mDesc << std::endl;
......@@ -293,8 +294,6 @@ int main(int argc, char* argv[])
d0_t_n.GenerateTensorValue(GeneratorTensor_3<D0DataType>{0.0, 1.0});
d1_e_n.GenerateTensorValue(GeneratorTensor_3<D1DataType>{0.0, 1.0});
}
-d0_t_n.savetxt("d0_t_n.txt", "int");
-d1_e_n.savetxt("d1_e_n.txt", "int");
DeviceMem sorted_token_ids_dev(sizeof(ck::index_t) * sorted_token_ids.mDesc.GetElementSpaceSize());
DeviceMem expert_ids_dev(sizeof(ck::index_t) * expert_ids.mDesc.GetElementSpaceSize());
DeviceMem max_token_id_dev(sizeof(ck::index_t) * max_token_id.mDesc.GetElementSpaceSize());
......@@ -304,13 +303,14 @@ int main(int argc, char* argv[])
DeviceMem d1_device_buf(sizeof(D1DataType) * d1_e_n.mDesc.GetElementSpaceSize());
DeviceMem e_device_buf(sizeof(EDataType) * e_t_n_device_result.mDesc.GetElementSpaceSize());
a0_t_k.savetxt("a.txt");
+d0_t_n.savetxt("d0_t_n.txt", "int");
+d1_e_n.savetxt("d1_e_n.txt", "int");
sorted_token_ids_dev.ToDevice(sorted_token_ids.mData.data());
expert_ids_dev.ToDevice(expert_ids.mData.data());
max_token_id_dev.ToDevice(max_token_id.mData.data());
a0_device_buf.ToDevice(a0_t_k.mData.data());
d0_device_buf.ToDevice(d0_t_n.mData.data());
d1_device_buf.ToDevice(d1_e_n.mData.data());
// e_device_buf.ToDevice(e_t_n_device_result.mData.data());
auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{};
......@@ -358,9 +358,9 @@ int main(int argc, char* argv[])
if (time_kernel) {
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
-std::size_t flop = std::size_t(2) * sorted_size * N * K;
+std::size_t flop = std::size_t(2) * valid_size * N * K;
std::size_t num_btype =
-sizeof(A0DataType) * sorted_size * K + sizeof(B0DataType) * K * N * experts + sizeof(EDataType) * sorted_size * N;
+sizeof(A0DataType) * valid_size * K + sizeof(B0DataType) * K * N * experts + sizeof(EDataType) * valid_size * N;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
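A quick unit check on the figures above, given that the invoker reports `ave_time` in milliseconds: FLOP / 1e9 / ms equals FLOP / 1e12 per second, i.e. TFLOPS, and bytes / 1e6 / ms equals GB/s. A hedged sketch with hypothetical helper names:

```cpp
// Hypothetical helpers making the unit conversions explicit
// (ave_time_ms is the measured kernel time in milliseconds).
inline float to_tflops(std::size_t flop, float ave_time_ms)
{
    return static_cast<float>(flop) / 1.E9f / ave_time_ms; // flop / 1e12 per second
}
inline float to_gb_per_sec(std::size_t num_bytes, float ave_time_ms)
{
    return static_cast<float>(num_bytes) / 1.E6f / ave_time_ms; // bytes / 1e9 per second
}
```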
......@@ -376,7 +376,7 @@ int main(int argc, char* argv[])
e_device_buf.FromDevice(e_t_n_device_result.mData.data());
-Tensor<CShuffleDataType> c_t_k_n({batch, topk, N}, {topk * N, N, 1});
+Tensor<CShuffleDataType> c_t_k_n({tokens, topk, N}, {topk * N, N, 1});
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceMoeGemm<A0DataType,
B0DataType,
......
......@@ -67,7 +67,9 @@ struct MulABScaleExpertWeight
const float& d1,
const float& d2) const
{
-e = ck::type_convert<EDataType>(c * d0 * d1 * d2);
+// e = ck::type_convert<EDataType>(c * d0 * d1 * d2);
+(void) d2;
+e = ck::type_convert<EDataType>(c * d0 * d1);
}
// for reference
template <>
......@@ -78,7 +80,8 @@ struct MulABScaleExpertWeight
const float& d1,
const float& d2) const
{
-e = ck::type_convert<EDataType>(c * d0 * d1 * d2);
+(void) d2;
+e = ck::type_convert<EDataType>(c * d0 * d1);
}
};
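With this change the elementwise epilogue applies only the A-scale `d0` and the B-scale `d1`; the top-k expert weight `d2` is still passed in but deliberately ignored, because gemm2 now folds it in during the scatter writeback (see the `scatter_weights` change further down). A minimal standalone sketch of the revised functor, under a hypothetical name:

```cpp
// Hypothetical sketch of the revised epilogue: E = C * d0 * d1,
// with the expert weight d2 accepted but unused at this stage.
struct MulABScale
{
    template <typename EDataType>
    __host__ __device__ void operator()(EDataType& e, const float& c, const float& d0,
                                        const float& d1, const float& d2) const
    {
        (void)d2; // the top-k weight is applied later, in the scatter epilogue
        e = ck::type_convert<EDataType>(c * d0 * d1);
    }
};
```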
......@@ -187,10 +190,12 @@ int main(int argc, char* argv[])
ck::index_t experts = 8;
ck::index_t sorted_tile_num = 9;
ck::index_t valid_tile_num = 8;
-ck::index_t sorted_tile_size = MPerBlock;
-ck::index_t sorted_size = sorted_tile_num * sorted_tile_size;
-ck::index_t valid_size = valid_tile_num * sorted_tile_size;
-ck::index_t tokens = 64;
+ck::index_t sorted_size = sorted_tile_num * MPerBlock;
+ck::index_t valid_size = valid_tile_num * MPerBlock;
+ck::index_t batch = 64;
+ck::index_t topk = 2;
+ck::index_t tokens = batch;
if(argc == 1)
{
......@@ -231,77 +236,83 @@ int main(int argc, char* argv[])
for (int i = 0; i < sorted_tile_num; i++) {
expert_ids.mData[i] = i;
}
-int token_per_tile = tokens / sorted_tile_num;
+int token_per_tile = tokens * topk / valid_tile_num;
int tokenid = 0;
// sorted_token_ids.mData[0] = 0;
for (int i = 0; i < sorted_size; i++) {
-int tile_off = i % sorted_tile_size;
+int tile_off = i % valid_size;
if(tile_off < token_per_tile)
-sorted_token_ids.mData[i] = tokenid++;
+{
+sorted_token_ids.mData[i] = (tokenid % batch) | ((tokenid / batch) << 24);
+tokenid++;
+}
else
{
sorted_token_ids.mData[i] = tokens;
}
-Tensor<A0DataType> a0_m_k(HostTensorDescriptor({sorted_size, K}, {K, 1}));
}
+expert_ids.savetxt("expert_ids.txt", "int");
+sorted_token_ids.savetxt("sorted_token_ids.txt", "int");
+Tensor<A0DataType> a0_t_k(HostTensorDescriptor({batch, K}, {K, 1}));
Tensor<B0DataType> b0_e_n_k(HostTensorDescriptor({experts, N, K}, {N*K, K, 1}));
Tensor<B0DataType> b0_preshuffled(HostTensorDescriptor({experts, N, K}, {N*K, K, 1}));
-Tensor<D0DataType> d0_m_n(HostTensorDescriptor({sorted_size, N}, {StrideDs[0], 0}));
+Tensor<D0DataType> d0_t_n(HostTensorDescriptor({batch, N}, {StrideDs[0], 0}));
Tensor<D1DataType> d1_e_n(HostTensorDescriptor({experts, N}, {1, StrideDs[1]}));
Tensor<D2DataType> d2_e_n(HostTensorDescriptor({sorted_size, N}, {1, 0}));
Tensor<EDataType> e_t_n_host_result(HostTensorDescriptor({tokens, N}, {N, 1}));
Tensor<EDataType> e_t_n_device_result(HostTensorDescriptor({tokens, N}, {N, 1}));
e_t_n_device_result.SetZero();
std::cout << "a0_m_k: " << a0_m_k.mDesc << std::endl;
std::cout << "a0_t_k: " << a0_t_k.mDesc << std::endl;
std::cout << "b0_e_n_k: " << b0_e_n_k.mDesc << std::endl;
std::cout << "d2_e_n: " << d2_e_n.mDesc << std::endl;
std::cout << "d1_e_n: " << d1_e_n.mDesc << std::endl;
std::cout << "d0_m_n: " << d0_m_n.mDesc << std::endl;
std::cout << "d0_t_n: " << d0_t_n.mDesc << std::endl;
std::cout << "e_t_n: " << e_t_n_host_result.mDesc << std::endl;
switch(init_method)
{
case 0: break;
case 1:
-a0_m_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-2, 2});
+a0_t_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-2, 2});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{0, 2});
-d0_m_n.GenerateTensorValue(GeneratorTensor_2<D0DataType>{-2, 2});
+d0_t_n.GenerateTensorValue(GeneratorTensor_2<D0DataType>{-2, 2});
d1_e_n.GenerateTensorValue(GeneratorTensor_2<D1DataType>{-2, 2});
d2_e_n.GenerateTensorValue(GeneratorTensor_2<D2DataType>{-2, 2});
break;
case 2:
-a0_m_k.GenerateTensorValue(GeneratorTensor_1<A0DataType>{});
+a0_t_k.GenerateTensorValue(GeneratorTensor_1<A0DataType>{});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_1<B0DataType>{});
-d0_m_n.GenerateTensorValue(GeneratorTensor_1<D0DataType>{});
+d0_t_n.GenerateTensorValue(GeneratorTensor_1<D0DataType>{});
d1_e_n.GenerateTensorValue(GeneratorTensor_1<D1DataType>{});
d2_e_n.GenerateTensorValue(GeneratorTensor_1<D2DataType>{});
break;
default:
-a0_m_k.GenerateTensorValue(GeneratorTensor_3<A0DataType>{0.0, 1.0});
+a0_t_k.GenerateTensorValue(GeneratorTensor_3<A0DataType>{0.0, 1.0});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_3<B0DataType>{-0.5, 0.5});
-d0_m_n.GenerateTensorValue(GeneratorTensor_3<D0DataType>{0.0, 1.0});
+d0_t_n.GenerateTensorValue(GeneratorTensor_3<D0DataType>{0.0, 1.0});
d1_e_n.GenerateTensorValue(GeneratorTensor_3<D1DataType>{0.0, 1.0});
d2_e_n.GenerateTensorValue(GeneratorTensor_3<D2DataType>{0.0, 1.0});
}
DeviceMem sorted_token_ids_dev(sizeof(ck::index_t) * sorted_token_ids.mDesc.GetElementSpaceSize());
DeviceMem expert_ids_dev(sizeof(ck::index_t) * expert_ids.mDesc.GetElementSpaceSize());
DeviceMem max_token_id_dev(sizeof(ck::index_t) * max_token_id.mDesc.GetElementSpaceSize());
-DeviceMem a0_device_buf(sizeof(A0DataType) * a0_m_k.mDesc.GetElementSpaceSize());
+DeviceMem a0_device_buf(sizeof(A0DataType) * a0_t_k.mDesc.GetElementSpaceSize());
DeviceMem b0_device_buf(sizeof(B0DataType) * b0_e_n_k.mDesc.GetElementSpaceSize());
-DeviceMem d0_device_buf(sizeof(D0DataType) * d0_m_n.mDesc.GetElementSpaceSize());
+DeviceMem d0_device_buf(sizeof(D0DataType) * d0_t_n.mDesc.GetElementSpaceSize());
DeviceMem d1_device_buf(sizeof(D1DataType) * d1_e_n.mDesc.GetElementSpaceSize());
DeviceMem d2_device_buf(sizeof(D2DataType) * d2_e_n.mDesc.GetElementSpaceSize());
DeviceMem e_device_buf(sizeof(EDataType) * e_t_n_device_result.mDesc.GetElementSpaceSize());
-a0_m_k.savetxt("a.txt");
+a0_t_k.savetxt("a.txt");
expert_ids.savetxt("expert_ids.txt", "int");
sorted_token_ids.savetxt("sorted_token_ids.txt", "int");
-d0_m_n.savetxt("d0_m_n.txt", "int");
+d0_t_n.savetxt("d0_t_n.txt", "int");
d1_e_n.savetxt("d1_e_n.txt", "int");
d2_e_n.savetxt("d2_e_n.txt", "int");
sorted_token_ids_dev.ToDevice(sorted_token_ids.mData.data());
expert_ids_dev.ToDevice(expert_ids.mData.data());
max_token_id_dev.ToDevice(max_token_id.mData.data());
-a0_device_buf.ToDevice(a0_m_k.mData.data());
-d0_device_buf.ToDevice(d0_m_n.mData.data());
+a0_device_buf.ToDevice(a0_t_k.mData.data());
+d0_device_buf.ToDevice(d0_t_n.mData.data());
d1_device_buf.ToDevice(d1_e_n.mData.data());
d2_device_buf.ToDevice(d2_e_n.mData.data());
e_device_buf.ToDevice(e_t_n_device_result.mData.data());
......@@ -332,6 +343,7 @@ int main(int argc, char* argv[])
d2_device_buf.GetDeviceBuffer()},
e_device_buf.GetDeviceBuffer(),
tokens,
+topk,
sorted_size,
N,
K,
......@@ -354,9 +366,9 @@ int main(int argc, char* argv[])
// results are not correct here because the output buffer has not been zeroed
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
-std::size_t flop = std::size_t(2) * sorted_size * N * K;
+std::size_t flop = std::size_t(2) * valid_size * N * K;
std::size_t num_btype =
-sizeof(A0DataType) * sorted_size * K + sizeof(B0DataType) * K * N * experts + sizeof(EDataType) * sorted_size * N;
+sizeof(A0DataType) * valid_size * K + sizeof(B0DataType) * K * N * experts + sizeof(EDataType) * valid_size * N;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
......@@ -387,12 +399,12 @@ int main(int argc, char* argv[])
auto ref_moe_gemm = ReferenceGemmInstance{};
auto ref_invoker = ref_moe_gemm.MakeInvoker();
auto ref_argument = ref_moe_gemm.MakeArgument(
-sorted_token_ids, expert_ids, max_token_id, sorted_tile_size, a0_m_k, b0_e_n_k, d0_m_n, d1_e_n, d2_e_n, c_t_n, PassThrough{}, PassThrough{}, cde_element_op);
+sorted_token_ids, expert_ids, max_token_id, MPerBlock, a0_t_k, b0_e_n_k, d0_t_n, d1_e_n, d2_e_n, c_t_n, PassThrough{}, PassThrough{}, cde_element_op);
ref_invoker.Run(ref_argument);
for(int t = 0; t < tokens; ++t)
{
for(int n = 0; n < N; ++n)
{
e_t_n_host_result(t, n) = ck::type_convert<EDataType>(c_t_n(t, n));
......@@ -402,7 +414,6 @@ int main(int argc, char* argv[])
e_device_buf.FromDevice(e_t_n_device_result.mData.data());
e_t_n_device_result.savetxt("out.txt");
e_t_n_host_result.savetxt("ref.txt");
return ck::utils::check_err(
e_t_n_device_result, e_t_n_host_result, "Error: Incorrect results!", 1e-3, 5e-2)
? 0
......
......@@ -638,7 +638,6 @@ struct GridwiseMoeGemmGather
BElementwiseOperation b_element_op_,
CElementwiseOperation c_element_op_)
: Problem{NumTokens_, TopK_, M_, N_, K_, StrideA_, StrideB_, StrideDs_, StrideC_, k_batch_},
p_sorted_token_ids{p_sorted_token_ids_},
p_sorted_expert_ids{p_sorted_expert_ids_},
p_max_token_id{p_max_token_id_},
......@@ -663,7 +662,6 @@ struct GridwiseMoeGemmGather
const index_t * p_sorted_token_ids;
const index_t * p_sorted_expert_ids;
const index_t * p_max_token_id;
const ADataType* p_a_grid;
const BDataType* p_b_grid;
DsGridPointer p_ds_grid;
......@@ -1146,7 +1144,7 @@ struct GridwiseMoeGemmGather
const auto b_grid_desc_bpreshuffled =
MakeBGridDescriptor_Preshuffled(problem.BN0Shuffled, problem.BK0Shuffled);
const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N<CLayout>(
-problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC);
+problem.NumTokens * problem.TopK, problem.MPadded, problem.N, problem.NPadded, problem.StrideC);
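Because gemm1 gathers its A rows per token, its C output is indexed by (token, top-k slot) rather than by sorted row, so the descriptor's M extent becomes `problem.NumTokens * problem.TopK`, matching the host-side `e_t_n` tensors of shape `{tokens, topk, N}`. A rough host-side sketch of the equivalent row-major indexing, with hypothetical names:

```cpp
// Hypothetical mirror of the C layout above: logically
// (NumTokens * TopK) rows by N columns, row-major with leading stride N.
inline int c_linear_index(int token, int k_slot, int n, int TopK, int N)
{
    return (token * TopK + k_slot) * N + n;
}
```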
// printf("tido %d size %d %d MNBLOCK %d %d %d %d\n", threadIdx.x, problem.StrideC, c_grid_desc_m_n.GetElementSpaceSize(),
// problem.MBlock, problem.NBlock, MPerBlock, NPerBlock);
const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
......@@ -1165,7 +1163,6 @@ struct GridwiseMoeGemmGather
constexpr auto AK1Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I2);
constexpr auto AKThreads = AK0Threads * AK1Threads;
constexpr auto AMRepeats = MPerBlock / AMThreads;
-// static_assert(MLoadRepeats == 1, "only support 1 line per thread now!");
const index_t token_pos = block_m_id * MPerBlock + threadIdx.x / AKThreads * AMRepeats;
if(token_pos >= max_token_id || token0 >= problem.NumTokens)
......@@ -1177,8 +1174,6 @@ struct GridwiseMoeGemmGather
gather_offsets(m0) = token_offset * problem.K;
// printf("init off tid %d m %d off %d\n", threadIdx.x, m0(), gather_offsets(m0));
});
-// const index_t m_block_data_idx_on_grid =
-// __builtin_amdgcn_readfirstlane(block_m_id * MPerBlock);
const index_t expert_stride = __builtin_amdgcn_readfirstlane(problem.N * problem.K);
// N0, K0, Blocksize*KPack
......
......@@ -9,7 +9,7 @@
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_selector.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1_mod8.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
......@@ -536,6 +536,7 @@ struct GridwiseMoeGemmScatter
struct Problem
{
__host__ __device__ Problem(index_t NumTokens_,
+index_t TopK_,
index_t M_,
index_t N_,
index_t K_,
......@@ -546,6 +547,7 @@ struct GridwiseMoeGemmScatter
index_t KBatch_)
:
NumTokens{NumTokens_},
+TopK{TopK_},
M{M_},
N{N_},
K{K_},
......@@ -571,6 +573,7 @@ struct GridwiseMoeGemmScatter
{
std::cout << "problem {"
<< "NumTokens:" << NumTokens << ", "
<< "TopK:" << TopK << ", "
<< "M:" << M << ", "
<< "N:" << N << ", "
<< "K:" << K << ", "
......@@ -588,6 +591,7 @@ struct GridwiseMoeGemmScatter
}
index_t NumTokens;
+index_t TopK;
index_t M;
index_t N;
index_t K;
......@@ -621,6 +625,7 @@ struct GridwiseMoeGemmScatter
std::array<const void*, NumDTensor> p_ds_grid_,
CDataType* p_c_grid_,
index_t NumTokens_,
+index_t TopK_,
index_t M_,
index_t N_,
index_t K_,
......@@ -632,8 +637,7 @@ struct GridwiseMoeGemmScatter
AElementwiseOperation a_element_op_,
BElementwiseOperation b_element_op_,
CElementwiseOperation c_element_op_)
-: Problem{NumTokens_, M_, N_, K_, StrideA_, StrideB_, StrideDs_, StrideC_, k_batch_},
+: Problem{NumTokens_, TopK_, M_, N_, K_, StrideA_, StrideB_, StrideDs_, StrideC_, k_batch_},
p_sorted_token_ids{p_sorted_token_ids_},
p_sorted_expert_ids{p_sorted_expert_ids_},
p_max_token_id{p_max_token_id_},
......@@ -1135,7 +1139,7 @@ struct GridwiseMoeGemmScatter
{
ignore = b_element_op;
const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
-problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0);
+problem.NumTokens, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0);
const auto b_grid_desc_bpreshuffled =
MakeBGridDescriptor_Preshuffled(problem.BN0Shuffled, problem.BK0Shuffled);
......@@ -1151,12 +1155,25 @@ struct GridwiseMoeGemmScatter
const index_t max_token_id = __builtin_amdgcn_readfirstlane(p_max_token_id[0]);
const index_t token0 = __builtin_amdgcn_readfirstlane(p_sorted_token_ids[block_m_id * MPerBlock] & 0xffffff);
-const index_t m_block_data_idx_on_grid =
-__builtin_amdgcn_readfirstlane(block_m_id * MPerBlock);
-const index_t expert_stride = __builtin_amdgcn_readfirstlane(problem.N * problem.K);
// constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
+constexpr auto AMThreads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
+constexpr auto AK0Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0);
+constexpr auto AK1Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I2);
+constexpr auto AKThreads = AK0Threads * AK1Threads;
+constexpr auto AMRepeats = MPerBlock / AMThreads;
+const index_t token_pos = block_m_id * MPerBlock + threadIdx.x / AKThreads * AMRepeats;
-if(m_block_data_idx_on_grid >= max_token_id || token0 >= problem.NumTokens)
+if(token_pos >= max_token_id || token0 >= problem.NumTokens)
return;
+StaticallyIndexedArray<index_t, AMRepeats> gather_offsets; //= p_sorted_token_ids[token_pos];
+static_for<0, AMRepeats, 1>{}([&](auto m0) {
+const index_t token_offset = (token_pos + m0 < max_token_id) ?
+(p_sorted_token_ids[token_pos + m0] & 0xffffff) : problem.NumTokens;
+gather_offsets(m0) = token_offset * problem.K;
+// printf("init off tid %d m %d off %d\n", threadIdx.x, m0(), gather_offsets(m0));
+});
+const index_t expert_stride = __builtin_amdgcn_readfirstlane(problem.N * problem.K);
// N0, K0, Blocksize*KPack
const index_t n_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_n_id * NXdlPerWave);
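Each thread of the scatter kernel now derives its own A gather offsets from the sorted token ids, exactly as the gather kernel already does: the low 24 bits of an entry select the source token row, and positions past `max_token_id` fall back to the `NumTokens` sentinel so padded rows read a dummy row instead of real data. A simplified scalar model of that setup, with hypothetical names:

```cpp
// Hypothetical scalar model of the per-thread gather-offset setup above.
// AKThreads threads cooperate along K; each thread covers am_repeats M rows.
void build_gather_offsets(const int* p_sorted_token_ids, int max_token_id,
                          int num_tokens, int K, int token_pos,
                          int am_repeats, int* gather_offsets)
{
    for(int m0 = 0; m0 < am_repeats; ++m0)
    {
        const int token_offset = (token_pos + m0 < max_token_id)
                                     ? (p_sorted_token_ids[token_pos + m0] & 0xffffff)
                                     : num_tokens; // sentinel row for padding slots
        gather_offsets[m0] = token_offset * K;     // element offset into the A grid
    }
}
```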
......@@ -1177,7 +1194,7 @@ struct GridwiseMoeGemmScatter
constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
// A matrix blockwise copy
auto a_blockwise_copy =
-ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
+ThreadGroupTensorSliceTransfer_v4r1_mod8<ThisThreadBlock,
AElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough,
InMemoryDataOperationEnum::Set,
......@@ -1198,13 +1215,15 @@ struct GridwiseMoeGemmScatter
1,
AThreadTransferSrcResetCoordinateAfterRun,
true,
+1,
BlockwiseGemmPipe::GlobalBufferNum>(
a_grid_desc_ak0_m_ak1,
-make_multi_index(0, m_block_data_idx_on_grid, 0),
+make_multi_index(0, 0, 0),
a_element_op,
a_block_desc_ak0_m_ak1,
make_multi_index(0, 0, 0),
-ck::tensor_operation::element_wise::PassThrough{});
+ck::tensor_operation::element_wise::PassThrough{},
+gather_offsets);
// Thread-wise copy
// K0 -> N0/NWave -> NWave -> KLane -> NLane -> KPack
......@@ -1384,11 +1403,11 @@ struct GridwiseMoeGemmScatter
using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
const DDataType *ptr_ = p_ds_grid[i];
// hacky logic here to support different kinds of strides; TODO: fix it.
-// ascale M, 1; bscale E, N, 1, move ptr to E
+// a-scale: (t, 1); b-scale: (E, N, 1); move the pointer to expert E
if (i.value == 1)
{
ptr_ += expert_id * (problem.StrideDs[1]? problem.StrideDs[1] * problem.N : 1);
-// if ( threadIdx.x ==0)
+// if ( threadIdx.x % 16 ==0)
// printf("bid %d eid %d b eoff %d %f\n", blockIdx.y, expert_id, expert_id * (problem.StrideDs[1]? problem.StrideDs[1] * problem.N : 1), ptr_[0]);
}
return make_dynamic_buffer<AddressSpaceEnum::Global>(
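The Ds[1] pointer bump above encodes two stride conventions in one expression: a nonzero `StrideDs[1]` means one B-scale per (expert, n) pair, so the pointer advances by `StrideDs[1] * N`, while zero means a single scalar per expert and the pointer advances by 1. A hedged sketch of that rule:

```cpp
// Hypothetical model of the per-expert B-scale pointer bump above.
template <typename T>
const T* expert_scale_ptr(const T* base, int expert_id, int stride_d1, int N)
{
    // stride_d1 != 0: per-(expert, n) scales; stride_d1 == 0: one scalar per expert
    return base + expert_id * (stride_d1 ? stride_d1 * N : 1);
}
```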
......@@ -1428,7 +1447,6 @@ struct GridwiseMoeGemmScatter
using CDEBlockTransferCluster =
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock;
const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation;
constexpr auto EMThreads = CDEBlockTransferCluster{}.At(I0) * CDEBlockTransferCluster{}.At(I1);
constexpr auto EMRepeats = MPerBlock / EMThreads;
constexpr auto ENThreads = CDEBlockTransferCluster{}.At(I2) * CDEBlockTransferCluster{}.At(I3);
......@@ -1436,10 +1454,12 @@ struct GridwiseMoeGemmScatter
StaticallyIndexedArray<index_t, EMRepeats> scatter_offsets; //= p_sorted_token_ids[c_token_pos];
StaticallyIndexedArray<float, EMRepeats> scatter_weights; //= for topk
// hacky: Ds[2] is hard-coded as the top-k weights (and Ds[0] as the A-scale); FIXME
-const float *p_sorted_weights = p_ds_grid[I2];
+const float *p_sorted_weights_2 = p_ds_grid[I2];
+const float *p_sorted_weights_0 = p_ds_grid[I0];
static_for<0, EMRepeats, 1>{}([&](auto m0) {
scatter_offsets(m0) = (p_sorted_token_ids[c_token_pos + m0] & 0xffffff) * problem.N;
-scatter_weights(m0) = p_sorted_weights[c_token_pos + m0];
+scatter_weights(m0) = p_sorted_weights_2[c_token_pos + m0]
+* p_sorted_weights_0[(c_token_pos + m0) * problem.StrideDs[0]];
// printf("init off bid %d tid %d m %d off %d\n", blockIdx.y, threadIdx.x, m0(), scatter_offsets(m0));
});
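This is the other half of dropping `d2` from the elementwise op: the scatter writeback now folds two per-row factors into `scatter_weights`, the top-k routing weight read from Ds[2] and the per-token A-scale read from Ds[0] through `StrideDs[0]`. A scalar sketch of the loop above, with hypothetical plain-array names:

```cpp
// Hypothetical scalar model of the scatter offset/weight setup above.
for(int m0 = 0; m0 < em_repeats; ++m0)
{
    const int pos   = c_token_pos + m0;                   // position in sorted order
    const int token = p_sorted_token_ids[pos] & 0xffffff; // destination token row
    scatter_offsets[m0] = token * N;                      // row offset into E
    scatter_weights[m0] = p_topk_weights[pos]             // top-k routing weight (Ds[2])
                        * p_a_scales[pos * stride_d0];    // per-token A-scale (Ds[0])
}
```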
......
......@@ -35,7 +35,7 @@ struct ReferenceMoeGemm2 : public device::BaseOperator
const Tensor<ck::index_t>& expert_ids,
const Tensor<ck::index_t>& max_token_id,
const index_t sorted_tile_size,
-const Tensor<ADataType>& a_m_k,
+const Tensor<ADataType>& a_t_k,
const Tensor<BDataType>& b_e_n_k,
const Tensor<D0DataType>& d0,
const Tensor<D1DataType>& d1,
......@@ -48,7 +48,7 @@ struct ReferenceMoeGemm2 : public device::BaseOperator
expert_ids_{expert_ids},
max_token_id_{max_token_id},
sorted_tile_size_{sorted_tile_size},
-a_m_k_{a_m_k},
+a_t_k_{a_t_k},
b_e_n_k_{b_e_n_k},
d0_{d0},
d1_{d1},
......@@ -64,7 +64,7 @@ struct ReferenceMoeGemm2 : public device::BaseOperator
const Tensor<ck::index_t>& expert_ids_;
const Tensor<ck::index_t>& max_token_id_;
index_t sorted_tile_size_;
-const Tensor<ADataType>& a_m_k_;
+const Tensor<ADataType>& a_t_k_;
const Tensor<BDataType>& b_e_n_k_;
const Tensor<D0DataType>& d0_;
const Tensor<D1DataType>& d1_;
......@@ -85,7 +85,7 @@ struct ReferenceMoeGemm2 : public device::BaseOperator
{
arg.c_t_n_.SetZero();
auto f_mk_kn_mn = [&](auto m, auto n) {
-const int K = arg.a_m_k_.mDesc.GetLengths()[1];
+const int K = arg.a_t_k_.mDesc.GetLengths()[1];
AccDataType v_acc{0};
ComputeTypeA v_a{0};
ComputeTypeB v_b{0};
......@@ -101,11 +101,11 @@ struct ReferenceMoeGemm2 : public device::BaseOperator
if constexpr(is_same_v<AElementwiseOperation,
ck::tensor_operation::element_wise::ConvertBF16RTN>)
{
-ck::tensor_operation::element_wise::PassThrough{}(v_a, arg.a_m_k_(m, k));
+ck::tensor_operation::element_wise::PassThrough{}(v_a, arg.a_t_k_(t, k));
}
else
{
-arg.a_element_op_(v_a, arg.a_m_k_(m, k));
+arg.a_element_op_(v_a, arg.a_t_k_(t, k));
}
// same for B matrix
if constexpr(is_same_v<BElementwiseOperation,
......@@ -157,7 +157,7 @@ struct ReferenceMoeGemm2 : public device::BaseOperator
const Tensor<ck::index_t>& expert_ids,
const Tensor<ck::index_t>& max_token_id,
const index_t sorted_tile_size,
-const Tensor<ADataType>& a_m_k,
+const Tensor<ADataType>& a_t_k,
const Tensor<BDataType>& b_e_n_k,
const Tensor<D0DataType>& d0,
const Tensor<D1DataType>& d1,
......@@ -167,7 +167,7 @@ struct ReferenceMoeGemm2 : public device::BaseOperator
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op)
{
-return Argument{sorted_token_ids, expert_ids, max_token_id, sorted_tile_size, a_m_k, b0_e_n_k, d0, d1, d2, c_t_n, a_element_op, b_element_op, c_element_op};
+return Argument{sorted_token_ids, expert_ids, max_token_id, sorted_tile_size, a_t_k, b0_e_n_k, d0, d1, d2, c_t_n, a_element_op, b_element_op, c_element_op};
}
static auto MakeInvoker() { return Invoker{}; }
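For completeness, a hedged sketch of how the example drives this reference on the host (`ReferenceGemmInstance` is assumed to alias this class with the example's template arguments):

```cpp
// Hypothetical usage, mirroring the host-verification path in the example.
auto ref_moe_gemm = ReferenceGemmInstance{};
auto ref_invoker  = ref_moe_gemm.MakeInvoker();
auto ref_argument = ref_moe_gemm.MakeArgument(
    sorted_token_ids, expert_ids, max_token_id, MPerBlock,
    a0_t_k, b0_e_n_k, d0_t_n, d1_e_n, d2_e_n, c_t_n,
    PassThrough{}, PassThrough{}, cde_element_op);
ref_invoker.Run(ref_argument);
```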
......