Commit f1cdecfb authored by Jing Zhang's avatar Jing Zhang
Browse files

fix

parent 426abafe
......@@ -78,7 +78,7 @@ int main(int argc, char* argv[])
exit(0);
}
int group_count = rand() % 16 + 1;
int group_count = 4;
// GEMM shape
std::vector<ck::tensor_operation::device::GemmTransposeDesc> gemm_descs;
......@@ -89,11 +89,13 @@ int main(int argc, char* argv[])
for(int i = 0; i < group_count; i++)
{
int M = 1024;
int N = 1024;
int K = 1024;
int B = 16;
int S = 64;
int nH = 16;
int hD = 64;
gemm_descs.push_back({M, N, K, K, K, N});
gemm_descs.push_back(
{B * S, nH * hD, nH * hD, nH * hD, nH * hD, B, S, nH, hD, S * nH * hD, S * hD, hD, 1});
}
auto f_host_tensor_descriptor =
......@@ -110,6 +112,19 @@ int main(int argc, char* argv[])
}
};
auto f_host_c_tensor_descriptor = [](std::size_t M0,
std::size_t M1,
std::size_t N0,
std::size_t N1,
std::size_t StrideM0,
std::size_t StrideM1,
std::size_t StrideN0,
std::size_t StrideN1) {
return HostTensorDescriptor(
std::vector<std::size_t>({M0, M1, N0, N1}),
std::vector<std::size_t>({StrideM0, StrideM1, StrideN0, StrideN1}));
};
std::vector<Tensor<ADataType>> a_tensors;
std::vector<Tensor<BDataType>> b_tensors;
std::vector<Tensor<CDataType>> c_host_tensors;
......@@ -136,10 +151,24 @@ int main(int argc, char* argv[])
gemm_descs[i].M, gemm_descs[i].K, gemm_descs[i].StrideA, ALayout{})));
b_tensors.push_back(Tensor<BDataType>(f_host_tensor_descriptor(
gemm_descs[i].K, gemm_descs[i].N, gemm_descs[i].StrideB, BLayout{})));
c_host_tensors.push_back(Tensor<CDataType>(f_host_tensor_descriptor(
gemm_descs[i].M, gemm_descs[i].N, gemm_descs[i].StrideC, CLayout{})));
c_device_tensors.push_back(Tensor<CDataType>(f_host_tensor_descriptor(
gemm_descs[i].M, gemm_descs[i].N, gemm_descs[i].StrideC, CLayout{})));
c_host_tensors.push_back(
Tensor<CDataType>(f_host_c_tensor_descriptor(gemm_descs[i].M0,
gemm_descs[i].M1,
gemm_descs[i].N0,
gemm_descs[i].N1,
gemm_descs[i].StrideM0,
gemm_descs[i].StrideM1,
gemm_descs[i].StrideN0,
gemm_descs[i].StrideN1)));
c_device_tensors.push_back(
Tensor<CDataType>(f_host_c_tensor_descriptor(gemm_descs[i].M0,
gemm_descs[i].M1,
gemm_descs[i].N0,
gemm_descs[i].N1,
gemm_descs[i].StrideM0,
gemm_descs[i].StrideM1,
gemm_descs[i].StrideN0,
gemm_descs[i].StrideN1)));
std::cout << "gemm[" << i << "] a_m_k: " << a_tensors[i].mDesc
<< " b_k_n: " << b_tensors[i].mDesc << " c_m_n: " << c_device_tensors[i].mDesc
......
......@@ -16,11 +16,11 @@ struct GemmDesc
struct GemmTransposeDesc
{
ck::index_t M, N, K;;
ck::index_t StrideA, StrideB, StrideC;
ck::index_t M, N, K;
ck::index_t StrideA, StrideB;
ck::index_t B, S, NumHead, HeadDim;
std::vector<ck::index_t> transpose;
ck::index_t M0, M1, N0, N1;
ck::index_t StrideM0, StrideM1, StrideN0, StrideN1;
};
template <typename AElementwiseOperation,
......@@ -51,7 +51,8 @@ template <typename AElementwiseOperation,
typename CElementwiseOperation>
struct DeviceGroupedGemmTranspose : public BaseOperator
{
virtual std::unique_ptr<BaseArgument> MakeArgumentPointer(std::vector<const void*>& p_a,
virtual std::unique_ptr<BaseArgument>
MakeArgumentPointer(std::vector<const void*>& p_a,
std::vector<const void*>& p_b,
std::vector<void*>& p_c,
std::vector<GemmTransposeDesc>& gemm_transpose_desc,
......@@ -66,10 +67,10 @@ struct DeviceGroupedGemmTranspose : public BaseOperator
template <typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation>
using DeviceGroupedGemmTransposePtr = std::unique_ptr<
DeviceGroupedGemmTranspose<AElementwiseOperation, BElementwiseOperation, CElementwiseOperation>>;
using DeviceGroupedGemmTransposePtr =
std::unique_ptr<DeviceGroupedGemmTranspose<AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation>>;
} // namespace device
} // namespace tensor_operation
......
......@@ -29,7 +29,8 @@ __global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_grouped_gemm_transpose_xdlops_v2r3(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const,
kernel_grouped_gemm_transpose_xdlops_v2r3(
const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const,
const index_t group_count,
const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op,
......@@ -111,8 +112,9 @@ template <typename ADataType,
ck::index_t CThreadTransferDstScalarPerVector,
ck::index_t NumPrefetch = 1,
ck::index_t MaxGroupCount = 10>
struct DeviceGroupedGemmTransposeXdl
: public DeviceGroupedGemmTranspose<AElementwiseOperation, BElementwiseOperation, CElementwiseOperation>
struct DeviceGroupedGemmTransposeXdl : public DeviceGroupedGemmTranspose<AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation>
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
......@@ -198,19 +200,29 @@ struct DeviceGroupedGemmTransposeXdl
}
}
static auto MakeCGridDescriptor_M_N(index_t M, index_t N, index_t StrideC)
static auto MakeCGridDescriptor_M_N(index_t M0,
index_t M1,
index_t N0,
index_t N1,
index_t StrideM0,
index_t StrideM1,
index_t StrideN0,
index_t StrideN1)
{
const auto c_grid_desc_m_n = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, CLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1));
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, CLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC));
}
const auto c_grid_desc_m0_m1_n0_n1 = make_naive_tensor_descriptor(
make_tuple(M0, M1, N0, N1), make_tuple(StrideM0, StrideM1, StrideN0, StrideN1));
return transform_tensor_descriptor(c_grid_desc_m0_m1_n0_n1,
make_tuple(make_merge_transform(make_tuple(M0, M1)),
make_merge_transform(make_tuple(N0, N1))),
make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}();
const index_t M = M0 * M1;
const index_t N = N0 * N1;
if constexpr(GemmSpec == GemmSpecialization::MNPadding)
{
const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock;
......@@ -235,7 +247,7 @@ struct DeviceGroupedGemmTransposeXdl
using AGridDesc_K0_M_K1 = decltype(MakeAGridDescriptor_K0_M_K1(1, 1, 1));
using BGridDesc_K0_N_K1 = decltype(MakeBGridDescriptor_K0_N_K1(1, 1, 1));
using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1));
using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1, 1, 1, 1, 1, 1));
// GridwiseGemm
using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3<
......@@ -384,14 +396,21 @@ struct DeviceGroupedGemmTransposeXdl
const index_t StrideA = gemm_transpose_desc[i].StrideA;
const index_t StrideB = gemm_transpose_desc[i].StrideB;
const index_t StrideC = gemm_transpose_desc[i].StrideC;
const auto a_grid_desc_k0_m_k1_ =
DeviceGroupedGemmTransposeXdl::MakeAGridDescriptor_K0_M_K1(M, K, StrideA);
const auto b_grid_desc_k0_n_k1_ =
DeviceGroupedGemmTransposeXdl::MakeBGridDescriptor_K0_N_K1(K, N, StrideB);
const auto c_grid_desc_m_n_ =
DeviceGroupedGemmTransposeXdl::MakeCGridDescriptor_M_N(M, N, StrideC);
DeviceGroupedGemmTransposeXdl::MakeCGridDescriptor_M_N(
gemm_transpose_desc[i].M0,
gemm_transpose_desc[i].M1,
gemm_transpose_desc[i].N0,
gemm_transpose_desc[i].N1,
gemm_transpose_desc[i].StrideM0,
gemm_transpose_desc[i].StrideM1,
gemm_transpose_desc[i].StrideN0,
gemm_transpose_desc[i].StrideN1);
const index_t grid_size_grp =
typename GroupedGemmBlock2CTileMap::UnderlyingBlock2CTileMap(
......@@ -501,7 +520,8 @@ struct DeviceGroupedGemmTransposeXdl
{
const auto kernel =
kernel_grouped_gemm_transpose_xdlops_v2r3<GridwiseGemm,
ADataType, // TODO: distiguish A/B datatype
ADataType, // TODO: distiguish A/B
// datatype
CDataType,
GemmDescKernelArg,
AElementwiseOperation,
......@@ -525,7 +545,8 @@ struct DeviceGroupedGemmTransposeXdl
{
const auto kernel =
kernel_grouped_gemm_transpose_xdlops_v2r3<GridwiseGemm,
ADataType, // TODO: distiguish A/B datatype
ADataType, // TODO: distiguish A/B
// datatype
CDataType,
GemmDescKernelArg,
AElementwiseOperation,
......@@ -585,13 +606,15 @@ struct DeviceGroupedGemmTransposeXdl
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op)
{
return Argument{p_a, p_b, p_c, gemm_transpose_desc, 1, 1, a_element_op, b_element_op, c_element_op};
return Argument{
p_a, p_b, p_c, gemm_transpose_desc, 1, 1, a_element_op, b_element_op, c_element_op};
}
static auto MakeInvoker() { return Invoker{}; }
// polymorphic
std::unique_ptr<BaseArgument> MakeArgumentPointer(std::vector<const void*>& p_a,
std::unique_ptr<BaseArgument>
MakeArgumentPointer(std::vector<const void*>& p_a,
std::vector<const void*>& p_b,
std::vector<void*>& p_c,
std::vector<GemmTransposeDesc>& gemm_transpose_desc,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment