"...composable_kernel_rocm.git" did not exist on "11e4082dd8459f3a0e69f7d164ef64eb7ebfa7fa"
Commit f2cc8405 authored by Jing Zhang's avatar Jing Zhang
Browse files

fixed comments

parent 35b07efb
@@ -78,7 +78,7 @@ int main(int argc, char* argv[])
         exit(0);
     }

-    int group_count = 4;
+    int group_count = rand() % 16 + 1;

     // GEMM shape
     std::vector<ck::tensor_operation::device::GemmTransposeDesc> gemm_descs;
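The randomized default maps rand() into [1, 16], so each run exercises a non-empty, bounded number of groups. A minimal standalone sketch of the idiom, outside this patch (the srand seeding below is an assumption added for varied runs; the example itself leaves the C RNG unseeded, so its draws repeat across executions):

```cpp
#include <cstdio>
#include <cstdlib>
#include <ctime>

int main()
{
    // Hypothetical seeding, not in the patch: without it, rand() yields the
    // same shape sequence on every execution.
    std::srand(static_cast<unsigned>(std::time(nullptr)));

    const int group_count = std::rand() % 16 + 1; // draw in [1, 16]
    std::printf("group_count = %d\n", group_count);
    return 0;
}
```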
@@ -89,66 +89,62 @@ int main(int argc, char* argv[])

     for(int i = 0; i < group_count; i++)
     {
-        int B       = 16;
-        int S       = 64;
-        int NumHead = 16;
-        int HeadDim = 64;
+        const int M0 = rand() % 4 + 1;
+        const int M1 = 256;
+        const int N0 = rand() % 4 + 1;
+        const int N1 = 256;

-        int M0 = B;
-        int M1 = S;
-        int N0 = NumHead;
-        int N1 = HeadDim;
-
-        int M = M0 * N1;
-        int N = N0 * N1;
-        int K = NumHead * HeadDim;
+        const int M = M0 * N1;
+        const int N = N0 * N1;
+        const int K = 128 * (rand() % 4 + 1);

-        int StrideA = K;
-        int StrideB = K;
+        const int stride_A = K;
+        const int stride_B = K;

         if(i % 2 == 0)
         {
-            int StrideM0 = S * NumHead * HeadDim;
-            int StrideM1 = 1;
-            int StrideN0 = S * HeadDim;
-            int StrideN1 = S;
+            // output layout [M0, N0, M1, N1]
+            const int stride_M0 = N1 * M1 * N0;
+            const int stride_M1 = N1;
+            const int stride_N0 = N1 * M1;
+            const int stride_N1 = 1;

             gemm_descs.push_back({M,
                                   N,
                                   K,
-                                  StrideA,
-                                  StrideB,
+                                  stride_A,
+                                  stride_B,
                                   M0,
                                   M1,
                                   N0,
                                   N1,
-                                  StrideM0,
-                                  StrideM1,
-                                  StrideN0,
-                                  StrideN1});
+                                  stride_M0,
+                                  stride_M1,
+                                  stride_N0,
+                                  stride_N1});
         }
         else
         {
-            int StrideM0 = S * NumHead * HeadDim;
-            int StrideM1 = HeadDim;
-            int StrideN0 = S * HeadDim;
-            int StrideN1 = 1;
+            // output layout [M0, N0, N1, M1]
+            int stride_M0 = N1 * N1 * N0;
+            int stride_M1 = 1;
+            int stride_N0 = M1 * N1;
+            int stride_N1 = M1;

             gemm_descs.push_back({M,
                                   N,
                                   K,
-                                  StrideA,
-                                  StrideB,
+                                  stride_A,
+                                  stride_B,
                                   M0,
                                   M1,
                                   N0,
                                   N1,
-                                  StrideM0,
-                                  StrideM1,
-                                  StrideN0,
-                                  StrideN1});
+                                  stride_M0,
+                                  stride_M1,
+                                  stride_N0,
+                                  stride_N1});
         }
     }
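The two branches describe the same logical [M0, M1] x [N0, N1] output tile in two memory orders: [M0, N0, M1, N1] when i is even, [M0, N0, N1, M1] otherwise. A standalone sketch (not CK code, names only mirror the example) that checks the even-branch strides are dense, i.e. the last logical element lands at offset M0*M1*N0*N1 - 1:

```cpp
#include <cstdio>

int main()
{
    // Illustrative shape; M1 and N1 are both 256 in the patched example.
    const int M0 = 2, M1 = 256, N0 = 3, N1 = 256;

    // Layout [M0, N0, M1, N1], strides as chosen in the i % 2 == 0 branch.
    const int sM0 = N1 * M1 * N0, sM1 = N1, sN0 = N1 * M1, sN1 = 1;

    // Offset of the logically-last element (m0, n0, m1, n1 all maxed out);
    // for a dense, non-overlapping layout it must equal M0*M1*N0*N1 - 1.
    const long last = long(M0 - 1) * sM0 + long(N0 - 1) * sN0 +
                      long(M1 - 1) * sM1 + long(N1 - 1) * sN1;
    const long expected = long(M0) * M1 * N0 * N1 - 1;

    std::printf("last offset = %ld, expected = %ld\n", last, expected);
    return last == expected ? 0 : 1;
}
```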
@@ -202,33 +198,33 @@ int main(int argc, char* argv[])

     for(std::size_t i = 0; i < gemm_descs.size(); i++)
     {
         a_tensors.push_back(Tensor<ADataType>(f_host_tensor_descriptor(
-            gemm_descs[i].M, gemm_descs[i].K, gemm_descs[i].StrideA, ALayout{})));
+            gemm_descs[i].M_, gemm_descs[i].K_, gemm_descs[i].stride_A_, ALayout{})));

         b_tensors.push_back(Tensor<BDataType>(f_host_tensor_descriptor(
-            gemm_descs[i].K, gemm_descs[i].N, gemm_descs[i].StrideB, BLayout{})));
+            gemm_descs[i].K_, gemm_descs[i].N_, gemm_descs[i].stride_B_, BLayout{})));

         c_host_tensors.push_back(
-            Tensor<CDataType>(f_host_c_tensor_descriptor(gemm_descs[i].M0,
-                                                         gemm_descs[i].M1,
-                                                         gemm_descs[i].N0,
-                                                         gemm_descs[i].N1,
-                                                         gemm_descs[i].StrideM0,
-                                                         gemm_descs[i].StrideM1,
-                                                         gemm_descs[i].StrideN0,
-                                                         gemm_descs[i].StrideN1)));
+            Tensor<CDataType>(f_host_c_tensor_descriptor(gemm_descs[i].M0_,
+                                                         gemm_descs[i].M1_,
+                                                         gemm_descs[i].N0_,
+                                                         gemm_descs[i].N1_,
+                                                         gemm_descs[i].stride_M0_,
+                                                         gemm_descs[i].stride_M1_,
+                                                         gemm_descs[i].stride_N0_,
+                                                         gemm_descs[i].stride_N1_)));

         c_device_tensors.push_back(
-            Tensor<CDataType>(f_host_c_tensor_descriptor(gemm_descs[i].M0,
-                                                         gemm_descs[i].M1,
-                                                         gemm_descs[i].N0,
-                                                         gemm_descs[i].N1,
-                                                         gemm_descs[i].StrideM0,
-                                                         gemm_descs[i].StrideM1,
-                                                         gemm_descs[i].StrideN0,
-                                                         gemm_descs[i].StrideN1)));
+            Tensor<CDataType>(f_host_c_tensor_descriptor(gemm_descs[i].M0_,
+                                                         gemm_descs[i].M1_,
+                                                         gemm_descs[i].N0_,
+                                                         gemm_descs[i].N1_,
+                                                         gemm_descs[i].stride_M0_,
+                                                         gemm_descs[i].stride_M1_,
+                                                         gemm_descs[i].stride_N0_,
+                                                         gemm_descs[i].stride_N1_)));

         std::cout << "gemm[" << i << "] a_m_k: " << a_tensors[i].mDesc
                   << " b_k_n: " << b_tensors[i].mDesc << " c_m_n: " << c_device_tensors[i].mDesc
                   << std::endl;

-        flop += std::size_t(2) * gemm_descs[i].M * gemm_descs[i].K * gemm_descs[i].N;
+        flop += std::size_t(2) * gemm_descs[i].M_ * gemm_descs[i].K_ * gemm_descs[i].N_;

         num_btype += sizeof(ADataType) * a_tensors[i].mDesc.GetElementSize() +
                      sizeof(BDataType) * b_tensors[i].mDesc.GetElementSize() +
                      sizeof(CDataType) * c_device_tensors[i].mDesc.GetElementSize();
......
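flop counts a multiply and an add per (m, n, k) triple, and num_btype totals the bytes behind A, B, and C. A hedged sketch of how such counters typically feed a throughput line (the millisecond timing input is assumed; the timing harness is elided from this diff):

```cpp
#include <cstddef>
#include <cstdio>

// Sketch only: ave_time_ms (milliseconds per launch) is assumed to come from
// the timing harness not shown in this excerpt.
void report_perf(std::size_t flop, std::size_t num_btype, float ave_time_ms)
{
    const float tflops     = static_cast<float>(flop) / 1.e9f / ave_time_ms;      // flop/ms -> Tflop/s
    const float gb_per_sec = static_cast<float>(num_btype) / 1.e6f / ave_time_ms; // B/ms -> GB/s
    std::printf("Perf: %f ms, %f TFlops, %f GB/s\n", ave_time_ms, tflops, gb_per_sec);
}
```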
......@@ -133,19 +133,19 @@ int main(int argc, char* argv[])

     for(std::size_t i = 0; i < gemm_descs.size(); i++)
     {
         a_tensors.push_back(Tensor<ADataType>(f_host_tensor_descriptor(
-            gemm_descs[i].M, gemm_descs[i].K, gemm_descs[i].StrideA, ALayout{})));
+            gemm_descs[i].M_, gemm_descs[i].K_, gemm_descs[i].stride_A_, ALayout{})));

         b_tensors.push_back(Tensor<BDataType>(f_host_tensor_descriptor(
-            gemm_descs[i].K, gemm_descs[i].N, gemm_descs[i].StrideB, BLayout{})));
+            gemm_descs[i].K_, gemm_descs[i].N_, gemm_descs[i].stride_B_, BLayout{})));

         c_host_tensors.push_back(Tensor<CDataType>(f_host_tensor_descriptor(
-            gemm_descs[i].M, gemm_descs[i].N, gemm_descs[i].StrideC, CLayout{})));
+            gemm_descs[i].M_, gemm_descs[i].N_, gemm_descs[i].stride_C_, CLayout{})));

         c_device_tensors.push_back(Tensor<CDataType>(f_host_tensor_descriptor(
-            gemm_descs[i].M, gemm_descs[i].N, gemm_descs[i].StrideC, CLayout{})));
+            gemm_descs[i].M_, gemm_descs[i].N_, gemm_descs[i].stride_C_, CLayout{})));

         std::cout << "gemm[" << i << "] a_m_k: " << a_tensors[i].mDesc
                   << " b_k_n: " << b_tensors[i].mDesc << " c_m_n: " << c_device_tensors[i].mDesc
                   << std::endl;

-        flop += std::size_t(2) * gemm_descs[i].M * gemm_descs[i].K * gemm_descs[i].N;
+        flop += std::size_t(2) * gemm_descs[i].M_ * gemm_descs[i].K_ * gemm_descs[i].N_;

         num_btype += sizeof(ADataType) * a_tensors[i].mDesc.GetElementSize() +
                      sizeof(BDataType) * b_tensors[i].mDesc.GetElementSize() +
                      sizeof(CDataType) * c_device_tensors[i].mDesc.GetElementSize();
......
......@@ -10,17 +10,17 @@ namespace device {

 struct GemmDesc
 {
-    ck::index_t M, N, K;
-    ck::index_t StrideA, StrideB, StrideC;
+    ck::index_t M_, N_, K_;
+    ck::index_t stride_A_, stride_B_, stride_C_;
 };

 struct GemmTransposeDesc
 {
-    ck::index_t M, N, K;
-    ck::index_t StrideA, StrideB;
+    ck::index_t M_, N_, K_;
+    ck::index_t stride_A_, stride_B_;

-    ck::index_t M0, M1, N0, N1;
-    ck::index_t StrideM0, StrideM1, StrideN0, StrideN1;
+    ck::index_t M0_, M1_, N0_, N1_;
+    ck::index_t stride_M0_, stride_M1_, stride_N0_, stride_N1_;
 };

 template <typename AElementwiseOperation,
......
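With the rename, every descriptor field carries a trailing underscore, and the device op below now enforces M_ == M0_ * M1_ and N_ == N0_ * N1_. A brace-initialization sketch under that convention (values are illustrative only; it assumes the CK header defining GemmTransposeDesc, shown above, is included):

```cpp
// Illustrative values only. Initializer order follows the member order of the
// renamed struct; the [M0, N0, M1, N1] output layout from the example is used.
ck::tensor_operation::device::GemmTransposeDesc desc{
    /* M_ */ 2 * 256, // must equal M0_ * M1_
    /* N_ */ 3 * 256, // must equal N0_ * N1_
    /* K_ */ 128,
    /* stride_A_ */ 128,
    /* stride_B_ */ 128,
    /* M0_ */ 2, /* M1_ */ 256, /* N0_ */ 3, /* N1_ */ 256,
    /* stride_M0_ */ 256 * 256 * 3,
    /* stride_M1_ */ 256,
    /* stride_N0_ */ 256 * 256,
    /* stride_N1_ */ 1};
```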
-#ifndef DEVICE_GROUPED_GEMM_XDL_HPP
-#define DEVICE_GROUPED_GEMM_XDL_HPP
+#ifndef DEVICE_GROUPED_GEMM_TRANSPOSE_XDL_HPP
+#define DEVICE_GROUPED_GEMM_TRANSPOSE_XDL_HPP

 #include <iostream>
 #include <sstream>
......@@ -389,12 +389,18 @@ struct DeviceGroupedGemmTransposeXdl : public DeviceGroupedGemmTranspose<AElemen

         for(std::size_t i = 0; i < gemm_transpose_desc.size(); i++)
         {
-            const index_t M = gemm_transpose_desc[i].M;
-            const index_t N = gemm_transpose_desc[i].N;
-            const index_t K = gemm_transpose_desc[i].K;
+            const index_t M = gemm_transpose_desc[i].M_;
+            const index_t N = gemm_transpose_desc[i].N_;
+            const index_t K = gemm_transpose_desc[i].K_;

-            const index_t StrideA = gemm_transpose_desc[i].StrideA;
-            const index_t StrideB = gemm_transpose_desc[i].StrideB;
+            const index_t StrideA = gemm_transpose_desc[i].stride_A_;
+            const index_t StrideB = gemm_transpose_desc[i].stride_B_;
+
+            if(!(M == gemm_transpose_desc[i].M0_ * gemm_transpose_desc[i].M1_ &&
+                 N == gemm_transpose_desc[i].N0_ * gemm_transpose_desc[i].N1_))
+            {
+                throw std::runtime_error("wrong! M != M0 * M1 or N != N0 * N1");
+            }

             const auto a_grid_desc_k0_m_k1_ =
                 DeviceGroupedGemmTransposeXdl::MakeAGridDescriptor_K0_M_K1(M, K, StrideA);
......@@ -402,14 +408,14 @@ struct DeviceGroupedGemmTransposeXdl : public DeviceGroupedGemmTranspose<AElemen
                 DeviceGroupedGemmTransposeXdl::MakeBGridDescriptor_K0_N_K1(K, N, StrideB);

             const auto c_grid_desc_m_n_ =
                 DeviceGroupedGemmTransposeXdl::MakeCGridDescriptor_M_N(
-                    gemm_transpose_desc[i].M0,
-                    gemm_transpose_desc[i].M1,
-                    gemm_transpose_desc[i].N0,
-                    gemm_transpose_desc[i].N1,
-                    gemm_transpose_desc[i].StrideM0,
-                    gemm_transpose_desc[i].StrideM1,
-                    gemm_transpose_desc[i].StrideN0,
-                    gemm_transpose_desc[i].StrideN1);
+                    gemm_transpose_desc[i].M0_,
+                    gemm_transpose_desc[i].M1_,
+                    gemm_transpose_desc[i].N0_,
+                    gemm_transpose_desc[i].N1_,
+                    gemm_transpose_desc[i].stride_M0_,
+                    gemm_transpose_desc[i].stride_M1_,
+                    gemm_transpose_desc[i].stride_N0_,
+                    gemm_transpose_desc[i].stride_N1_);

             const index_t grid_size_grp =
                 GroupedGemmBlock2CTileMap(c_grid_desc_m_n_, M01, N01, 0)
......
......@@ -377,13 +377,13 @@ struct DeviceGroupedGemmXdl

         for(std::size_t i = 0; i < gemm_descs.size(); i++)
         {
-            const index_t M = gemm_descs[i].M;
-            const index_t N = gemm_descs[i].N;
-            const index_t K = gemm_descs[i].K;
+            const index_t M = gemm_descs[i].M_;
+            const index_t N = gemm_descs[i].N_;
+            const index_t K = gemm_descs[i].K_;

-            const index_t StrideA = gemm_descs[i].StrideA;
-            const index_t StrideB = gemm_descs[i].StrideB;
-            const index_t StrideC = gemm_descs[i].StrideC;
+            const index_t StrideA = gemm_descs[i].stride_A_;
+            const index_t StrideB = gemm_descs[i].stride_B_;
+            const index_t StrideC = gemm_descs[i].stride_C_;

             const auto a_grid_desc_k0_m_k1_ =
                 DeviceGroupedGemmXdl::MakeAGridDescriptor_K0_M_K1(M, K, StrideA);
......
......@@ -63,8 +63,8 @@ struct ReferenceGemmTranspose : public device::BaseOperator

                 float v_a;
                 float v_b;

-                arg.a_element_op_(v_a, static_cast<const float>(arg.a_m_k_(m, k)));
-                arg.b_element_op_(v_b, static_cast<const float>(arg.b_k_n_(k, n)));
+                arg.a_element_op_(v_a, ck::type_convert<const float>(arg.a_m_k_(m, k)));
+                arg.b_element_op_(v_b, ck::type_convert<const float>(arg.b_k_n_(k, n)));

                 v_acc += v_a * v_b;
             }
......
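The reference kernel now routes conversions through ck::type_convert, CK's customization point for numeric casts, instead of a bare static_cast. A rough standalone sketch of the pattern with a stand-in template (my_type_convert is hypothetical; the real ck::type_convert lives in CK's headers and also covers CK's custom fp16/bf16 types):

```cpp
#include <cstdio>

// Stand-in for ck::type_convert: one customization point for all element-type
// conversions, so a reference kernel never hard-codes casts. Specializations
// would handle non-builtin storage types.
template <typename Y, typename X>
Y my_type_convert(X x)
{
    return static_cast<Y>(x);
}

int main()
{
    // 2x2x2 reference GEMM accumulating in float from a narrow storage type.
    signed char a[2][2] = {{1, 2}, {3, 4}};
    signed char b[2][2] = {{5, 6}, {7, 8}};
    float c[2][2]       = {};

    for(int m = 0; m < 2; ++m)
        for(int n = 0; n < 2; ++n)
            for(int k = 0; k < 2; ++k)
                c[m][n] += my_type_convert<float>(a[m][k]) *
                           my_type_convert<float>(b[k][n]);

    std::printf("c[0][0] = %f\n", c[0][0]); // 1*5 + 2*7 = 19
    return 0;
}
```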
......@@ -107,13 +107,13 @@ bool TestGroupedGemm(DeviceGroupedGemmPtr_& groupedGemmPtr)

     for(std::size_t i = 0; i < gemm_descs.size(); i++)
     {
         a_tensors.emplace_back(Tensor<ADataType>(f_host_tensor_descriptor(
-            gemm_descs[i].M, gemm_descs[i].K, gemm_descs[i].StrideA, ALayout{})));
+            gemm_descs[i].M_, gemm_descs[i].K_, gemm_descs[i].stride_A_, ALayout{})));

         b_tensors.emplace_back(Tensor<BDataType>(f_host_tensor_descriptor(
-            gemm_descs[i].K, gemm_descs[i].N, gemm_descs[i].StrideB, BLayout{})));
+            gemm_descs[i].K_, gemm_descs[i].N_, gemm_descs[i].stride_B_, BLayout{})));

         c_host_tensors.emplace_back(Tensor<CDataType>(f_host_tensor_descriptor(
-            gemm_descs[i].M, gemm_descs[i].N, gemm_descs[i].StrideC, CLayout{})));
+            gemm_descs[i].M_, gemm_descs[i].N_, gemm_descs[i].stride_C_, CLayout{})));

         c_device_tensors.emplace_back(Tensor<CDataType>(f_host_tensor_descriptor(
-            gemm_descs[i].M, gemm_descs[i].N, gemm_descs[i].StrideC, CLayout{})));
+            gemm_descs[i].M_, gemm_descs[i].N_, gemm_descs[i].stride_C_, CLayout{})));

         a_tensors[i].GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
         b_tensors[i].GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
......