Commit 65d67fb7 authored by Jing Zhang

add ptr to GemmDesc

parent f9b740b5
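
This commit replaces the per-group integer offsets (OffsetA/OffsetB/OffsetC) in the grouped-GEMM descriptor with per-group device pointers (p_a/p_b/p_c), renames the struct from gemm_desc to GemmShape, and drops the separate base-pointer arguments from MakeArgument/MakeArgumentPointer and from the kernel launch. The following is a minimal host-side sketch of the new calling pattern, not part of the commit itself; the struct matches the definition further down in the diff, and the shape values mirror the updated example.

// Sketch only: the new per-group descriptor and how a caller fills it.
#include <cstdint>
#include <vector>

namespace ck {
using index_t = std::int32_t;

struct GemmShape
{
    index_t M, N, K;
    index_t StrideA, StrideB, StrideC;
    void *p_a, *p_b, *p_c; // per-group device pointers (new in this commit)
};
} // namespace ck

// Shapes are first recorded with null pointers; the pointers are patched in
// later, once the flat device buffers exist (see the pointer-setup hunk below).
std::vector<ck::GemmShape> make_gemm_shapes(int group_count)
{
    std::vector<ck::GemmShape> shapes;
    for(int i = 0; i < group_count; i++)
    {
        const ck::index_t M = 256 + 256 * i;
        const ck::index_t N = 128 + 128 * i;
        const ck::index_t K = 64 + 64 * i;
        shapes.push_back({M, N, K, /*StrideA*/ K, /*StrideB*/ K, /*StrideC*/ N,
                          nullptr, nullptr, nullptr});
    }
    return shapes;
}
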
@@ -79,21 +79,21 @@ int main(int argc, char* argv[])
int group_count = 4;
// GEMM shape
std::vector<ck::gemm_desc> gemm_shapes;
std::vector<ck::GemmShape> gemm_shapes;
int A_size = 0, B_size = 0, C_size = 0;
for(int i = 0; i < group_count; i++)
{
int M = 3840;
int N = 1024;
int K = 4096;
int M = 256 + 256 * i;
int N = 128 + 128 * i;
int K = 64 + 64 * i;
gemm_shapes.push_back({M, N, K, K, K, N, A_size, B_size, C_size});
gemm_shapes.push_back({M, N, K, K, K, N, nullptr, nullptr, nullptr});
A_size += M * K;
B_size += N * K;
C_size += M * N;
A_size += gemm_shapes[i].M * gemm_shapes[i].K;
B_size += gemm_shapes[i].N * gemm_shapes[i].K;
C_size += gemm_shapes[i].M * gemm_shapes[i].N;
}
auto f_host_tensor_descriptor =
@@ -163,12 +163,27 @@ int main(int argc, char* argv[])
std::vector<ADataType> a_tensors_data, b_tensors_data, c_tensors_data;
A_size = 0;
B_size = 0;
C_size = 0;
for(int i = 0; i < gemm_shapes.size(); i++)
{
a_tensors_data.insert(
a_tensors_data.end(), a_tensors[i].mData.begin(), a_tensors[i].mData.end());
b_tensors_data.insert(
b_tensors_data.end(), b_tensors[i].mData.begin(), b_tensors[i].mData.end());
gemm_shapes[i].p_a =
static_cast<ADataType*>(a_tensors_device_buf.GetDeviceBuffer()) + A_size;
gemm_shapes[i].p_b =
static_cast<BDataType*>(b_tensors_device_buf.GetDeviceBuffer()) + B_size;
gemm_shapes[i].p_c =
static_cast<CDataType*>(c_tensors_device_buf.GetDeviceBuffer()) + C_size;
A_size += gemm_shapes[i].M * gemm_shapes[i].K;
B_size += gemm_shapes[i].N * gemm_shapes[i].K;
C_size += gemm_shapes[i].M * gemm_shapes[i].N;
}
a_tensors_device_buf.ToDevice(a_tensors_data.data());
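
The example no longer records offsets in the descriptor; it derives each group's typed device pointer by walking the flat A/B/C allocations and accumulating element counts. Below is a hypothetical standalone helper capturing the same pattern; the real example does this inline on the buffers returned by GetDeviceBuffer(), and the GemmShape type is the one sketched above.

// Hypothetical helper (not in the commit), assuming one flat device
// allocation per matrix. Mirrors the pointer-patching loop in the hunk above.
#include <cstddef>
#include <vector>

template <typename AT, typename BT, typename CT>
void assign_group_pointers(std::vector<ck::GemmShape>& shapes,
                           AT* a_base, BT* b_base, CT* c_base)
{
    std::size_t a_off = 0, b_off = 0, c_off = 0;
    for(auto& s : shapes)
    {
        s.p_a = a_base + a_off; // this group's slice of the flat A buffer
        s.p_b = b_base + b_off;
        s.p_c = c_base + c_off;
        a_off += static_cast<std::size_t>(s.M) * s.K;
        b_off += static_cast<std::size_t>(s.N) * s.K;
        c_off += static_cast<std::size_t>(s.M) * s.N;
    }
}
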
@@ -179,16 +194,9 @@ int main(int argc, char* argv[])
auto c_element_op = CElementOp{};
// do GEMM
auto gemm = DeviceGemmInstance{};
auto invoker = gemm.MakeInvoker();
auto argument =
gemm.MakeArgument(static_cast<ADataType*>(a_tensors_device_buf.GetDeviceBuffer()),
static_cast<BDataType*>(b_tensors_device_buf.GetDeviceBuffer()),
static_cast<CDataType*>(c_tensors_device_buf.GetDeviceBuffer()),
gemm_shapes,
a_element_op,
b_element_op,
c_element_op);
auto gemm = DeviceGemmInstance{};
auto invoker = gemm.MakeInvoker();
auto argument = gemm.MakeArgument(gemm_shapes, a_element_op, b_element_op, c_element_op);
if(!gemm.IsSupportedArgument(argument))
{
@@ -210,11 +218,14 @@ int main(int argc, char* argv[])
c_tensors_device_buf.FromDevice(c_tensors_data.data());
C_size = 0;
for(int i = 0; i < gemm_shapes.size(); i++)
{
memcpy(c_device_tensors[i].mData.data(),
c_tensors_data.data() + gemm_shapes[i].OffsetC,
c_tensors_data.data() + C_size,
c_device_tensors[i].mData.size() * sizeof(CDataType));
C_size += gemm_shapes[i].M * gemm_shapes[i].N;
}
if(do_verification)
@@ -177,11 +177,11 @@ enum ActivTypeEnum_t
using index_t = int32_t;
using long_index_t = int64_t;
struct gemm_desc
struct GemmShape
{
ck::index_t M, N, K;
ck::index_t StrideA, StrideB, StrideC;
ck::index_t OffsetA, OffsetB, OffsetC;
void *p_a, *p_b, *p_c;
};
} // namespace ck
@@ -64,10 +64,7 @@ template <typename AElementwiseOperation,
typename CElementwiseOperation>
struct DeviceGroupedGemm : public BaseOperator
{
virtual std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
const void* p_b,
void* p_c,
std::vector<gemm_desc> gemm_shapes,
virtual std::unique_ptr<BaseArgument> MakeArgumentPointer(std::vector<GemmShape> gemm_shapes,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op,
@@ -233,99 +233,84 @@ struct DeviceGroupedGemmXdl
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_;
typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_;
ck::index_t OffsetA, OffsetB, OffsetC;
const ADataType* a_ptr;
const BDataType* b_ptr;
CDataType* c_ptr;
ck::index_t BlockStart, BlockEnd;
};
// Argument
struct Argument : public BaseArgument
{
Argument(const ADataType* p_a_grid,
const BDataType* p_b_grid,
CDataType* p_c_grid,
std::vector<gemm_desc> gemm_shapes,
Argument(std::vector<GemmShape> gemm_shapes,
index_t M01,
index_t N01,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op)
: p_a_grid_{p_a_grid},
p_b_grid_{p_b_grid},
p_c_grid_{p_c_grid},
gemm_shapes_{gemm_shapes},
M01_{M01},
: M01_{M01},
N01_{N01},
a_element_op_{a_element_op},
b_element_op_{b_element_op},
c_element_op_{c_element_op}
{
grid_size = 0;
static_for<0, MaxGroupCount, 1>{}([&](auto i) {
if(i < gemm_shapes_.size())
{
const index_t M = gemm_shapes_[i].M;
const index_t N = gemm_shapes_[i].N;
const index_t K = gemm_shapes_[i].K;
const index_t StrideA = gemm_shapes_[i].StrideA;
const index_t StrideB = gemm_shapes_[i].StrideB;
const index_t StrideC = gemm_shapes_[i].StrideC;
gemm_desc_(i).a_grid_desc_k0_m_k1_ =
DeviceGroupedGemmXdl::MakeAGridDescriptor_K0_M_K1(M, K, StrideA);
gemm_desc_(i).b_grid_desc_k0_n_k1_ =
DeviceGroupedGemmXdl::MakeBGridDescriptor_K0_N_K1(K, N, StrideB);
gemm_desc_(i).c_grid_desc_m_n_ =
DeviceGroupedGemmXdl::MakeCGridDescriptor_M_N(M, N, StrideC);
const index_t grid_size_grp =
GridwiseGemm::CalculateGridSize(gemm_desc_[i].c_grid_desc_m_n_);
gemm_desc_(i).BlockStart = grid_size;
gemm_desc_(i).BlockEnd = grid_size + grid_size_grp;
grid_size += grid_size_grp;
gemm_desc_(i).OffsetA = gemm_shapes_[i].OffsetA;
gemm_desc_(i).OffsetB = gemm_shapes_[i].OffsetB;
gemm_desc_(i).OffsetC = gemm_shapes_[i].OffsetC;
if(GridwiseGemm::CheckValidity(gemm_desc_[i].a_grid_desc_k0_m_k1_,
gemm_desc_[i].b_grid_desc_k0_n_k1_,
gemm_desc_[i].c_grid_desc_m_n_,
M01_,
N01_))
{
gemm_desc_(i).c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_ =
GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(
gemm_desc_[i].c_grid_desc_m_n_);
for(index_t i = 0; i < gemm_shapes.size(); i++)
{
const index_t M = gemm_shapes[i].M;
const index_t N = gemm_shapes[i].N;
const index_t K = gemm_shapes[i].K;
gemm_desc_(i).block_2_ctile_map_ = GridwiseGemm::MakeDefaultBlock2CTileMap(
gemm_desc_[i].c_grid_desc_m_n_, M01, N01);
}
}
else
const index_t StrideA = gemm_shapes[i].StrideA;
const index_t StrideB = gemm_shapes[i].StrideB;
const index_t StrideC = gemm_shapes[i].StrideC;
const auto a_grid_desc_k0_m_k1_ =
DeviceGroupedGemmXdl::MakeAGridDescriptor_K0_M_K1(M, K, StrideA);
const auto b_grid_desc_k0_n_k1_ =
DeviceGroupedGemmXdl::MakeBGridDescriptor_K0_N_K1(K, N, StrideB);
const auto c_grid_desc_m_n_ =
DeviceGroupedGemmXdl::MakeCGridDescriptor_M_N(M, N, StrideC);
const index_t grid_size_grp = GridwiseGemm::CalculateGridSize(c_grid_desc_m_n_);
const index_t BlockStart = grid_size;
const index_t BlockEnd = grid_size + grid_size_grp;
grid_size += grid_size_grp;
if(GridwiseGemm::CheckValidity(
a_grid_desc_k0_m_k1_, b_grid_desc_k0_n_k1_, c_grid_desc_m_n_, M01_, N01_))
{
gemm_desc_(i).BlockStart = -1;
gemm_desc_(i).BlockEnd = -1;
const auto c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_ =
GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m_n_);
const auto block_2_ctile_map_ =
GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01);
GemmShape_.push_back(GemmDesc{a_grid_desc_k0_m_k1_,
b_grid_desc_k0_n_k1_,
c_grid_desc_m_n_,
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_,
block_2_ctile_map_,
static_cast<const ADataType*>(gemm_shapes[i].p_a),
static_cast<const BDataType*>(gemm_shapes[i].p_b),
static_cast<CDataType*>(gemm_shapes[i].p_c),
BlockStart,
BlockEnd});
}
});
}
}
// private:
const ADataType* p_a_grid_;
const BDataType* p_b_grid_;
CDataType* p_c_grid_;
std::vector<gemm_desc> gemm_shapes_;
index_t M01_;
index_t N01_;
AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_;
CElementwiseOperation c_element_op_;
StaticallyIndexedArray<GemmDesc, MaxGroupCount> gemm_desc_;
std::vector<GemmDesc> GemmShape_;
index_t grid_size;
};
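
In the new Argument constructor, each group's grid descriptors are built on the fly and the group is assigned a contiguous half-open block range [BlockStart, BlockEnd) by accumulating per-group grid sizes; the running total becomes grid_size for the launch. Below is a simplified sketch of just the block-range bookkeeping; grid_sizes stands in for GridwiseGemm::CalculateGridSize applied to each group's C descriptor.

// Simplified sketch of the block-range accumulation, not the real Argument.
#include <vector>

struct BlockRange
{
    ck::index_t BlockStart, BlockEnd; // half-open range of workgroup ids
};

std::vector<BlockRange> assign_block_ranges(const std::vector<ck::index_t>& grid_sizes)
{
    std::vector<BlockRange> ranges;
    ck::index_t grid_size = 0;
    for(const ck::index_t g : grid_sizes)
    {
        ranges.push_back({grid_size, grid_size + g});
        grid_size += g; // running total = dim3(grid_size) used at launch
    }
    return ranges;
}
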
@@ -337,44 +322,51 @@ struct DeviceGroupedGemmXdl
float Run(const Argument& arg, int nrepeat = 1)
{
StaticallyIndexedArray<GemmDesc, MaxGroupCount> GemmShape_arg;
bool has_main_k0_block_loop = true;
static_for<0, MaxGroupCount, 1>{}([&](auto i) {
if(i < arg.gemm_shapes_.size())
if(i < arg.GemmShape_.size())
{
GemmShape_arg(i) = arg.GemmShape_[i];
std::cout << "arg.a_grid_desc_k0_m_k1_{"
<< arg.gemm_desc_[i].a_grid_desc_k0_m_k1_.GetLength(I0) << ", "
<< arg.gemm_desc_[i].a_grid_desc_k0_m_k1_.GetLength(I1) << ", "
<< arg.gemm_desc_[i].a_grid_desc_k0_m_k1_.GetLength(I2) << "}"
<< GemmShape_arg[i].a_grid_desc_k0_m_k1_.GetLength(I0) << ", "
<< GemmShape_arg[i].a_grid_desc_k0_m_k1_.GetLength(I1) << ", "
<< GemmShape_arg[i].a_grid_desc_k0_m_k1_.GetLength(I2) << "}"
<< std::endl;
std::cout << "arg.b_grid_desc_k0_n_k1_{"
<< arg.gemm_desc_[i].b_grid_desc_k0_n_k1_.GetLength(I0) << ", "
<< arg.gemm_desc_[i].b_grid_desc_k0_n_k1_.GetLength(I1) << ", "
<< arg.gemm_desc_[i].b_grid_desc_k0_n_k1_.GetLength(I2) << "}"
<< GemmShape_arg[i].b_grid_desc_k0_n_k1_.GetLength(I0) << ", "
<< GemmShape_arg[i].b_grid_desc_k0_n_k1_.GetLength(I1) << ", "
<< GemmShape_arg[i].b_grid_desc_k0_n_k1_.GetLength(I2) << "}"
<< std::endl;
std::cout << "arg.c_grid_desc_m_n_{ "
<< arg.gemm_desc_[i].c_grid_desc_m_n_.GetLength(I0) << ", "
<< arg.gemm_desc_[i].c_grid_desc_m_n_.GetLength(I1) << "}"
<< GemmShape_arg[i].c_grid_desc_m_n_.GetLength(I0) << ", "
<< GemmShape_arg[i].c_grid_desc_m_n_.GetLength(I1) << "}"
<< std::endl;
std::cout << "Block: " << arg.gemm_desc_[i].BlockStart << ", "
<< arg.gemm_desc_[i].BlockEnd << std::endl;
if(!GridwiseGemm::CheckValidity(arg.gemm_desc_[i].a_grid_desc_k0_m_k1_,
arg.gemm_desc_[i].b_grid_desc_k0_n_k1_,
arg.gemm_desc_[i].c_grid_desc_m_n_,
if(!GridwiseGemm::CheckValidity(GemmShape_arg[i].a_grid_desc_k0_m_k1_,
GemmShape_arg[i].b_grid_desc_k0_n_k1_,
GemmShape_arg[i].c_grid_desc_m_n_,
arg.M01_,
arg.N01_))
{
throw std::runtime_error(
"wrong! GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 has invalid setting");
}
}
});
const auto K0 = arg.gemm_desc_[I0].a_grid_desc_k0_m_k1_.GetLength(I0);
const auto K0 = GemmShape_arg[i].a_grid_desc_k0_m_k1_.GetLength(I0);
const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0);
if(GridwiseGemm::CalculateHasMainK0BlockLoop(K0) != has_main_k0_block_loop)
{
throw std::runtime_error(
"wrong! GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 has invalid setting");
}
}
});
float ave_time = 0;
@@ -396,11 +388,8 @@ struct DeviceGroupedGemmXdl
dim3(arg.grid_size),
dim3(BlockSize),
0,
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
arg.gemm_desc_,
arg.gemm_shapes_.size(),
GemmShape_arg,
arg.GemmShape_.size(),
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_);
@@ -423,11 +412,8 @@ struct DeviceGroupedGemmXdl
dim3(arg.grid_size),
dim3(BlockSize),
0,
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
arg.gemm_desc_,
arg.gemm_shapes_.size(),
GemmShape_arg,
arg.GemmShape_.size(),
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_);
@@ -451,9 +437,9 @@ struct DeviceGroupedGemmXdl
static bool IsSupportedArgument(const Argument& arg)
{
return GridwiseGemm::CheckValidity(arg.gemm_desc_[Number<0>{}].a_grid_desc_k0_m_k1_,
arg.gemm_desc_[Number<0>{}].b_grid_desc_k0_n_k1_,
arg.gemm_desc_[Number<0>{}].c_grid_desc_m_n_,
return GridwiseGemm::CheckValidity(arg.GemmShape_[0].a_grid_desc_k0_m_k1_,
arg.GemmShape_[0].b_grid_desc_k0_n_k1_,
arg.GemmShape_[0].c_grid_desc_m_n_,
arg.M01_,
arg.N01_);
}
@@ -464,38 +450,25 @@ struct DeviceGroupedGemmXdl
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
}
static auto MakeArgument(const ADataType* p_a,
const BDataType* p_b,
CDataType* p_c,
std::vector<gemm_desc> gemm_shapes,
static auto MakeArgument(std::vector<GemmShape> gemm_shapes,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op)
{
return Argument{p_a, p_b, p_c, gemm_shapes, 1, 1, a_element_op, b_element_op, c_element_op};
return Argument{gemm_shapes, 1, 1, a_element_op, b_element_op, c_element_op};
}
static auto MakeInvoker() { return Invoker{}; }
// polymorphic
std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
const void* p_b,
void* p_c,
std::vector<gemm_desc> gemm_shapes,
std::unique_ptr<BaseArgument> MakeArgumentPointer(std::vector<GemmShape> gemm_shapes,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op,
index_t /* KBatch */ = 1) override
{
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
static_cast<const BDataType*>(p_b),
static_cast<CDataType*>(p_c),
gemm_shapes,
1,
1,
a_element_op,
b_element_op,
c_element_op);
return std::make_unique<Argument>(
gemm_shapes, 1, 1, a_element_op, b_element_op, c_element_op);
}
// polymorphic
@@ -26,9 +26,6 @@ __global__ void
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_grouped_gemm_xdlops_v2r3(
const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
const StaticallyIndexedArray<GemmDesc, MaxGroupCount> gemm_desc_,
const index_t group_count,
const AElementwiseOperation a_element_op,
@@ -41,18 +38,16 @@ __global__ void
#if 1
static_for<0, MaxGroupCount, 1>{}([&](auto i) {
if(block_id >= gemm_desc_[i].BlockStart && block_id < gemm_desc_[i].BlockEnd)
if(block_id >= gemm_desc_[i].BlockStart && block_id < gemm_desc_[i].BlockEnd &&
i < group_count)
{
auto group_id = i;
const index_t block_id_grp = block_id - gemm_desc_[group_id].BlockStart;
const index_t a_offset_grp = gemm_desc_[group_id].OffsetA;
const index_t b_offset_grp = gemm_desc_[group_id].OffsetB;
const index_t c_offset_grp = gemm_desc_[group_id].OffsetC;
GridwiseGemm::template Run<HasMainK0BlockLoop>(
p_a_grid + a_offset_grp,
p_b_grid + b_offset_grp,
p_c_grid + c_offset_grp,
gemm_desc_[group_id].a_ptr,
gemm_desc_[group_id].b_ptr,
gemm_desc_[group_id].c_ptr,
p_shared,
gemm_desc_[group_id].a_grid_desc_k0_m_k1_,
gemm_desc_[group_id].b_grid_desc_k0_n_k1_,
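
On the device side, the kernel no longer receives base pointers plus per-group offsets; each workgroup searches the descriptor array for the group whose [BlockStart, BlockEnd) range contains its block id and then runs the gridwise GEMM with the pointers stored in that descriptor. Below is a host-side sketch of the lookup logic; the kernel itself uses static_for over MaxGroupCount, as in the hunk above.

// Host-side analogue of the per-workgroup group lookup; Array is any type
// exposing operator[] over descriptors with BlockStart/BlockEnd members.
template <typename Array>
int find_group(const Array& gemm_desc, int group_count, int block_id)
{
    for(int g = 0; g < group_count; ++g)
    {
        if(block_id >= gemm_desc[g].BlockStart && block_id < gemm_desc[g].BlockEnd)
            return g; // this block then uses gemm_desc[g].p_a / p_b / p_c
    }
    return -1; // block is not mapped to any group
}
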