Commit f2cc8405 authored by Jing Zhang
Browse files

fixed comments

parent 35b07efb
...@@ -78,7 +78,7 @@ int main(int argc, char* argv[]) ...@@ -78,7 +78,7 @@ int main(int argc, char* argv[])
exit(0); exit(0);
} }
int group_count = 4; int group_count = rand() % 16 + 1;
// GEMM shape // GEMM shape
std::vector<ck::tensor_operation::device::GemmTransposeDesc> gemm_descs; std::vector<ck::tensor_operation::device::GemmTransposeDesc> gemm_descs;
...@@ -89,66 +89,62 @@ int main(int argc, char* argv[]) ...@@ -89,66 +89,62 @@ int main(int argc, char* argv[])
for(int i = 0; i < group_count; i++) for(int i = 0; i < group_count; i++)
{ {
int B = 16; const int M0 = rand() % 4 + 1;
int S = 64; const int M1 = 256;
int NumHead = 16; const int N0 = rand() % 4 + 1;
int HeadDim = 64; const int N1 = 256;
int M0 = B; const int M = M0 * N1;
int M1 = S; const int N = N0 * N1;
int N0 = NumHead;
int N1 = HeadDim;
int M = M0 * N1; const int K = 128 * (rand() % 4 + 1);
int N = N0 * N1;
int K = NumHead * HeadDim;
int StrideA = K; const int stride_A = K;
int StrideB = K; const int stride_B = K;
if(i % 2 == 0) if(i % 2 == 0)
{ {
// output layout [M0, N0, M1, N1]
int StrideM0 = S * NumHead * HeadDim; const int stride_M0 = N1 * M1 * N0;
int StrideM1 = 1; const int stride_M1 = N1;
int StrideN0 = S * HeadDim; const int stride_N0 = N1 * M1;
int StrideN1 = S; const int stride_N1 = 1;
gemm_descs.push_back({M, gemm_descs.push_back({M,
N, N,
K, K,
StrideA, stride_A,
StrideB, stride_B,
M0, M0,
M1, M1,
N0, N0,
N1, N1,
StrideM0, stride_M0,
StrideM1, stride_M1,
StrideN0, stride_N0,
StrideN1}); stride_N1});
} }
else else
{ {
// output layout [M0, N0, N1, M1]
int StrideM0 = S * NumHead * HeadDim; int stride_M0 = N1 * N1 * N0;
int StrideM1 = HeadDim; int stride_M1 = 1;
int StrideN0 = S * HeadDim; int stride_N0 = M1 * N1;
int StrideN1 = 1; int stride_N1 = M1;
gemm_descs.push_back({M, gemm_descs.push_back({M,
N, N,
K, K,
StrideA, stride_A,
StrideB, stride_B,
M0, M0,
M1, M1,
N0, N0,
N1, N1,
StrideM0, stride_M0,
StrideM1, stride_M1,
StrideN0, stride_N0,
StrideN1}); stride_N1});
} }
} }
...@@ -202,33 +198,33 @@ int main(int argc, char* argv[]) ...@@ -202,33 +198,33 @@ int main(int argc, char* argv[])
for(std::size_t i = 0; i < gemm_descs.size(); i++) for(std::size_t i = 0; i < gemm_descs.size(); i++)
{ {
a_tensors.push_back(Tensor<ADataType>(f_host_tensor_descriptor( a_tensors.push_back(Tensor<ADataType>(f_host_tensor_descriptor(
gemm_descs[i].M, gemm_descs[i].K, gemm_descs[i].StrideA, ALayout{}))); gemm_descs[i].M_, gemm_descs[i].K_, gemm_descs[i].stride_A_, ALayout{})));
b_tensors.push_back(Tensor<BDataType>(f_host_tensor_descriptor( b_tensors.push_back(Tensor<BDataType>(f_host_tensor_descriptor(
gemm_descs[i].K, gemm_descs[i].N, gemm_descs[i].StrideB, BLayout{}))); gemm_descs[i].K_, gemm_descs[i].N_, gemm_descs[i].stride_B_, BLayout{})));
c_host_tensors.push_back( c_host_tensors.push_back(
Tensor<CDataType>(f_host_c_tensor_descriptor(gemm_descs[i].M0, Tensor<CDataType>(f_host_c_tensor_descriptor(gemm_descs[i].M0_,
gemm_descs[i].M1, gemm_descs[i].M1_,
gemm_descs[i].N0, gemm_descs[i].N0_,
gemm_descs[i].N1, gemm_descs[i].N1_,
gemm_descs[i].StrideM0, gemm_descs[i].stride_M0_,
gemm_descs[i].StrideM1, gemm_descs[i].stride_M1_,
gemm_descs[i].StrideN0, gemm_descs[i].stride_N0_,
gemm_descs[i].StrideN1))); gemm_descs[i].stride_N1_)));
c_device_tensors.push_back( c_device_tensors.push_back(
Tensor<CDataType>(f_host_c_tensor_descriptor(gemm_descs[i].M0, Tensor<CDataType>(f_host_c_tensor_descriptor(gemm_descs[i].M0_,
gemm_descs[i].M1, gemm_descs[i].M1_,
gemm_descs[i].N0, gemm_descs[i].N0_,
gemm_descs[i].N1, gemm_descs[i].N1_,
gemm_descs[i].StrideM0, gemm_descs[i].stride_M0_,
gemm_descs[i].StrideM1, gemm_descs[i].stride_M1_,
gemm_descs[i].StrideN0, gemm_descs[i].stride_N0_,
gemm_descs[i].StrideN1))); gemm_descs[i].stride_N1_)));
std::cout << "gemm[" << i << "] a_m_k: " << a_tensors[i].mDesc std::cout << "gemm[" << i << "] a_m_k: " << a_tensors[i].mDesc
<< " b_k_n: " << b_tensors[i].mDesc << " c_m_n: " << c_device_tensors[i].mDesc << " b_k_n: " << b_tensors[i].mDesc << " c_m_n: " << c_device_tensors[i].mDesc
<< std::endl; << std::endl;
flop += std::size_t(2) * gemm_descs[i].M * gemm_descs[i].K * gemm_descs[i].N; flop += std::size_t(2) * gemm_descs[i].M_ * gemm_descs[i].K_ * gemm_descs[i].N_;
num_btype += sizeof(ADataType) * a_tensors[i].mDesc.GetElementSize() + num_btype += sizeof(ADataType) * a_tensors[i].mDesc.GetElementSize() +
sizeof(BDataType) * b_tensors[i].mDesc.GetElementSize() + sizeof(BDataType) * b_tensors[i].mDesc.GetElementSize() +
sizeof(CDataType) * c_device_tensors[i].mDesc.GetElementSize(); sizeof(CDataType) * c_device_tensors[i].mDesc.GetElementSize();
......
...@@ -133,19 +133,19 @@ int main(int argc, char* argv[]) ...@@ -133,19 +133,19 @@ int main(int argc, char* argv[])
for(std::size_t i = 0; i < gemm_descs.size(); i++) for(std::size_t i = 0; i < gemm_descs.size(); i++)
{ {
a_tensors.push_back(Tensor<ADataType>(f_host_tensor_descriptor( a_tensors.push_back(Tensor<ADataType>(f_host_tensor_descriptor(
gemm_descs[i].M, gemm_descs[i].K, gemm_descs[i].StrideA, ALayout{}))); gemm_descs[i].M_, gemm_descs[i].K_, gemm_descs[i].stride_A_, ALayout{})));
b_tensors.push_back(Tensor<BDataType>(f_host_tensor_descriptor( b_tensors.push_back(Tensor<BDataType>(f_host_tensor_descriptor(
gemm_descs[i].K, gemm_descs[i].N, gemm_descs[i].StrideB, BLayout{}))); gemm_descs[i].K_, gemm_descs[i].N_, gemm_descs[i].stride_B_, BLayout{})));
c_host_tensors.push_back(Tensor<CDataType>(f_host_tensor_descriptor( c_host_tensors.push_back(Tensor<CDataType>(f_host_tensor_descriptor(
gemm_descs[i].M, gemm_descs[i].N, gemm_descs[i].StrideC, CLayout{}))); gemm_descs[i].M_, gemm_descs[i].N_, gemm_descs[i].stride_C_, CLayout{})));
c_device_tensors.push_back(Tensor<CDataType>(f_host_tensor_descriptor( c_device_tensors.push_back(Tensor<CDataType>(f_host_tensor_descriptor(
gemm_descs[i].M, gemm_descs[i].N, gemm_descs[i].StrideC, CLayout{}))); gemm_descs[i].M_, gemm_descs[i].N_, gemm_descs[i].stride_C_, CLayout{})));
std::cout << "gemm[" << i << "] a_m_k: " << a_tensors[i].mDesc std::cout << "gemm[" << i << "] a_m_k: " << a_tensors[i].mDesc
<< " b_k_n: " << b_tensors[i].mDesc << " c_m_n: " << c_device_tensors[i].mDesc << " b_k_n: " << b_tensors[i].mDesc << " c_m_n: " << c_device_tensors[i].mDesc
<< std::endl; << std::endl;
flop += std::size_t(2) * gemm_descs[i].M * gemm_descs[i].K * gemm_descs[i].N; flop += std::size_t(2) * gemm_descs[i].M_ * gemm_descs[i].K_ * gemm_descs[i].N_;
num_btype += sizeof(ADataType) * a_tensors[i].mDesc.GetElementSize() + num_btype += sizeof(ADataType) * a_tensors[i].mDesc.GetElementSize() +
sizeof(BDataType) * b_tensors[i].mDesc.GetElementSize() + sizeof(BDataType) * b_tensors[i].mDesc.GetElementSize() +
sizeof(CDataType) * c_device_tensors[i].mDesc.GetElementSize(); sizeof(CDataType) * c_device_tensors[i].mDesc.GetElementSize();
......
...@@ -10,17 +10,17 @@ namespace device { ...@@ -10,17 +10,17 @@ namespace device {
struct GemmDesc struct GemmDesc
{ {
ck::index_t M, N, K; ck::index_t M_, N_, K_;
ck::index_t StrideA, StrideB, StrideC; ck::index_t stride_A_, stride_B_, stride_C_;
}; };
struct GemmTransposeDesc struct GemmTransposeDesc
{ {
ck::index_t M, N, K; ck::index_t M_, N_, K_;
ck::index_t StrideA, StrideB; ck::index_t stride_A_, stride_B_;
ck::index_t M0, M1, N0, N1; ck::index_t M0_, M1_, N0_, N1_;
ck::index_t StrideM0, StrideM1, StrideN0, StrideN1; ck::index_t stride_M0_, stride_M1_, stride_N0_, stride_N1_;
}; };
template <typename AElementwiseOperation, template <typename AElementwiseOperation,
......
#ifndef DEVICE_GROUPED_GEMM_XDL_HPP #ifndef DEVICE_GROUPED_GEMM_TRANSPOSE_XDL_HPP
#define DEVICE_GROUPED_GEMM_XDL_HPP #define DEVICE_GROUPED_GEMM_TRANSPOSE_XDL_HPP
#include <iostream> #include <iostream>
#include <sstream> #include <sstream>
...@@ -389,12 +389,18 @@ struct DeviceGroupedGemmTransposeXdl : public DeviceGroupedGemmTranspose<AElemen ...@@ -389,12 +389,18 @@ struct DeviceGroupedGemmTransposeXdl : public DeviceGroupedGemmTranspose<AElemen
for(std::size_t i = 0; i < gemm_transpose_desc.size(); i++) for(std::size_t i = 0; i < gemm_transpose_desc.size(); i++)
{ {
const index_t M = gemm_transpose_desc[i].M; const index_t M = gemm_transpose_desc[i].M_;
const index_t N = gemm_transpose_desc[i].N; const index_t N = gemm_transpose_desc[i].N_;
const index_t K = gemm_transpose_desc[i].K; const index_t K = gemm_transpose_desc[i].K_;
const index_t StrideA = gemm_transpose_desc[i].StrideA; const index_t StrideA = gemm_transpose_desc[i].stride_A_;
const index_t StrideB = gemm_transpose_desc[i].StrideB; const index_t StrideB = gemm_transpose_desc[i].stride_B_;
if(!(M == gemm_transpose_desc[i].M0_ * gemm_transpose_desc[i].M1_ &&
N == gemm_transpose_desc[i].N0_ * gemm_transpose_desc[i].N1_))
{
throw std::runtime_error("wrong! M != M0 * M1 or N != N0 * N1");
}
const auto a_grid_desc_k0_m_k1_ = const auto a_grid_desc_k0_m_k1_ =
DeviceGroupedGemmTransposeXdl::MakeAGridDescriptor_K0_M_K1(M, K, StrideA); DeviceGroupedGemmTransposeXdl::MakeAGridDescriptor_K0_M_K1(M, K, StrideA);
...@@ -402,14 +408,14 @@ struct DeviceGroupedGemmTransposeXdl : public DeviceGroupedGemmTranspose<AElemen ...@@ -402,14 +408,14 @@ struct DeviceGroupedGemmTransposeXdl : public DeviceGroupedGemmTranspose<AElemen
DeviceGroupedGemmTransposeXdl::MakeBGridDescriptor_K0_N_K1(K, N, StrideB); DeviceGroupedGemmTransposeXdl::MakeBGridDescriptor_K0_N_K1(K, N, StrideB);
const auto c_grid_desc_m_n_ = const auto c_grid_desc_m_n_ =
DeviceGroupedGemmTransposeXdl::MakeCGridDescriptor_M_N( DeviceGroupedGemmTransposeXdl::MakeCGridDescriptor_M_N(
gemm_transpose_desc[i].M0, gemm_transpose_desc[i].M0_,
gemm_transpose_desc[i].M1, gemm_transpose_desc[i].M1_,
gemm_transpose_desc[i].N0, gemm_transpose_desc[i].N0_,
gemm_transpose_desc[i].N1, gemm_transpose_desc[i].N1_,
gemm_transpose_desc[i].StrideM0, gemm_transpose_desc[i].stride_M0_,
gemm_transpose_desc[i].StrideM1, gemm_transpose_desc[i].stride_M1_,
gemm_transpose_desc[i].StrideN0, gemm_transpose_desc[i].stride_N0_,
gemm_transpose_desc[i].StrideN1); gemm_transpose_desc[i].stride_N1_);
const index_t grid_size_grp = const index_t grid_size_grp =
GroupedGemmBlock2CTileMap(c_grid_desc_m_n_, M01, N01, 0) GroupedGemmBlock2CTileMap(c_grid_desc_m_n_, M01, N01, 0)
......
...@@ -377,13 +377,13 @@ struct DeviceGroupedGemmXdl ...@@ -377,13 +377,13 @@ struct DeviceGroupedGemmXdl
for(std::size_t i = 0; i < gemm_descs.size(); i++) for(std::size_t i = 0; i < gemm_descs.size(); i++)
{ {
const index_t M = gemm_descs[i].M; const index_t M = gemm_descs[i].M_;
const index_t N = gemm_descs[i].N; const index_t N = gemm_descs[i].N_;
const index_t K = gemm_descs[i].K; const index_t K = gemm_descs[i].K_;
const index_t StrideA = gemm_descs[i].StrideA; const index_t StrideA = gemm_descs[i].stride_A_;
const index_t StrideB = gemm_descs[i].StrideB; const index_t StrideB = gemm_descs[i].stride_B_;
const index_t StrideC = gemm_descs[i].StrideC; const index_t StrideC = gemm_descs[i].stride_C_;
const auto a_grid_desc_k0_m_k1_ = const auto a_grid_desc_k0_m_k1_ =
DeviceGroupedGemmXdl::MakeAGridDescriptor_K0_M_K1(M, K, StrideA); DeviceGroupedGemmXdl::MakeAGridDescriptor_K0_M_K1(M, K, StrideA);
......
...@@ -63,8 +63,8 @@ struct ReferenceGemmTranspose : public device::BaseOperator ...@@ -63,8 +63,8 @@ struct ReferenceGemmTranspose : public device::BaseOperator
float v_a; float v_a;
float v_b; float v_b;
arg.a_element_op_(v_a, static_cast<const float>(arg.a_m_k_(m, k))); arg.a_element_op_(v_a, ck::type_convert<const float>(arg.a_m_k_(m, k)));
arg.b_element_op_(v_b, static_cast<const float>(arg.b_k_n_(k, n))); arg.b_element_op_(v_b, ck::type_convert<const float>(arg.b_k_n_(k, n)));
v_acc += v_a * v_b; v_acc += v_a * v_b;
} }
......
...@@ -107,13 +107,13 @@ bool TestGroupedGemm(DeviceGroupedGemmPtr_& groupedGemmPtr) ...@@ -107,13 +107,13 @@ bool TestGroupedGemm(DeviceGroupedGemmPtr_& groupedGemmPtr)
for(std::size_t i = 0; i < gemm_descs.size(); i++) for(std::size_t i = 0; i < gemm_descs.size(); i++)
{ {
a_tensors.emplace_back(Tensor<ADataType>(f_host_tensor_descriptor( a_tensors.emplace_back(Tensor<ADataType>(f_host_tensor_descriptor(
gemm_descs[i].M, gemm_descs[i].K, gemm_descs[i].StrideA, ALayout{}))); gemm_descs[i].M_, gemm_descs[i].K_, gemm_descs[i].stride_A_, ALayout{})));
b_tensors.emplace_back(Tensor<BDataType>(f_host_tensor_descriptor( b_tensors.emplace_back(Tensor<BDataType>(f_host_tensor_descriptor(
gemm_descs[i].K, gemm_descs[i].N, gemm_descs[i].StrideB, BLayout{}))); gemm_descs[i].K_, gemm_descs[i].N_, gemm_descs[i].stride_B_, BLayout{})));
c_host_tensors.emplace_back(Tensor<CDataType>(f_host_tensor_descriptor( c_host_tensors.emplace_back(Tensor<CDataType>(f_host_tensor_descriptor(
gemm_descs[i].M, gemm_descs[i].N, gemm_descs[i].StrideC, CLayout{}))); gemm_descs[i].M_, gemm_descs[i].N_, gemm_descs[i].stride_C_, CLayout{})));
c_device_tensors.emplace_back(Tensor<CDataType>(f_host_tensor_descriptor( c_device_tensors.emplace_back(Tensor<CDataType>(f_host_tensor_descriptor(
gemm_descs[i].M, gemm_descs[i].N, gemm_descs[i].StrideC, CLayout{}))); gemm_descs[i].M_, gemm_descs[i].N_, gemm_descs[i].stride_C_, CLayout{})));
a_tensors[i].GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5}); a_tensors[i].GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
b_tensors[i].GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5}); b_tensors[i].GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment