Commit bb9c4a89 authored by Jing Zhang's avatar Jing Zhang
Browse files

fixed comments: reserve, separate ptr and gemm_shapes

parent c3952566
...@@ -79,7 +79,11 @@ int main(int argc, char* argv[]) ...@@ -79,7 +79,11 @@ int main(int argc, char* argv[])
int group_count = 4; int group_count = 4;
// GEMM shape // GEMM shape
std::vector<ck::GemmShape> gemm_shapes; std::vector<ck::tensor_operation::device::GemmShape> gemm_shapes;
std::vector<const void*> p_a, p_b;
std::vector<void*> p_c;
gemm_shapes.reserve(group_count);
for(int i = 0; i < group_count; i++) for(int i = 0; i < group_count; i++)
{ {
...@@ -87,7 +91,7 @@ int main(int argc, char* argv[]) ...@@ -87,7 +91,7 @@ int main(int argc, char* argv[])
int N = 128 + 128 * i; int N = 128 + 128 * i;
int K = 64 + 64 * i; int K = 64 + 64 * i;
gemm_shapes.push_back({M, N, K, K, K, N, nullptr, nullptr, nullptr}); gemm_shapes.push_back({M, N, K, K, K, N});
} }
auto f_host_tensor_descriptor = auto f_host_tensor_descriptor =
...@@ -105,14 +109,24 @@ int main(int argc, char* argv[]) ...@@ -105,14 +109,24 @@ int main(int argc, char* argv[])
}; };
std::vector<Tensor<ADataType>> a_tensors; std::vector<Tensor<ADataType>> a_tensors;
;
std::vector<Tensor<BDataType>> b_tensors; std::vector<Tensor<BDataType>> b_tensors;
std::vector<Tensor<CDataType>> c_host_tensors; std::vector<Tensor<CDataType>> c_host_tensors;
std::vector<Tensor<CDataType>> c_device_tensors; std::vector<Tensor<CDataType>> c_device_tensors;
a_tensors.reserve(group_count);
b_tensors.reserve(group_count);
c_host_tensors.reserve(group_count);
c_device_tensors.reserve(group_count);
using DeviceMemPtr = std::unique_ptr<DeviceMem>; using DeviceMemPtr = std::unique_ptr<DeviceMem>;
std::vector<DeviceMemPtr> a_tensors_device, b_tensors_device, c_tensors_device; std::vector<DeviceMemPtr> a_tensors_device, b_tensors_device, c_tensors_device;
a_tensors_device.reserve(group_count);
b_tensors_device.reserve(group_count);
c_tensors_device.reserve(group_count);
std::size_t flop = 0, num_btype = 0; std::size_t flop = 0, num_btype = 0;
for(int i = 0; i < gemm_shapes.size(); i++) for(int i = 0; i < gemm_shapes.size(); i++)
...@@ -164,9 +178,9 @@ int main(int argc, char* argv[]) ...@@ -164,9 +178,9 @@ int main(int argc, char* argv[])
a_tensors_device[i]->ToDevice(a_tensors[i].mData.data()); a_tensors_device[i]->ToDevice(a_tensors[i].mData.data());
b_tensors_device[i]->ToDevice(b_tensors[i].mData.data()); b_tensors_device[i]->ToDevice(b_tensors[i].mData.data());
gemm_shapes[i].p_a = a_tensors_device[i]->GetDeviceBuffer(); p_a.push_back(a_tensors_device[i]->GetDeviceBuffer());
gemm_shapes[i].p_b = b_tensors_device[i]->GetDeviceBuffer(); p_b.push_back(b_tensors_device[i]->GetDeviceBuffer());
gemm_shapes[i].p_c = c_tensors_device[i]->GetDeviceBuffer(); p_c.push_back(c_tensors_device[i]->GetDeviceBuffer());
} }
auto a_element_op = AElementOp{}; auto a_element_op = AElementOp{};
...@@ -174,9 +188,10 @@ int main(int argc, char* argv[]) ...@@ -174,9 +188,10 @@ int main(int argc, char* argv[])
auto c_element_op = CElementOp{}; auto c_element_op = CElementOp{};
// do GEMM // do GEMM
auto gemm = DeviceGemmInstance{}; auto gemm = DeviceGemmInstance{};
auto invoker = gemm.MakeInvoker(); auto invoker = gemm.MakeInvoker();
auto argument = gemm.MakeArgument(gemm_shapes, a_element_op, b_element_op, c_element_op); auto argument =
gemm.MakeArgument(p_a, p_b, p_c, gemm_shapes, a_element_op, b_element_op, c_element_op);
if(!gemm.IsSupportedArgument(argument)) if(!gemm.IsSupportedArgument(argument))
{ {
......
...@@ -177,12 +177,5 @@ enum ActivTypeEnum_t ...@@ -177,12 +177,5 @@ enum ActivTypeEnum_t
using index_t = int32_t; using index_t = int32_t;
using long_index_t = int64_t; using long_index_t = int64_t;
// Describes one GEMM problem of a grouped GEMM: its extents, the operand
// strides, and the buffers the operands live in.
struct GemmShape
{
// GEMM extents.
ck::index_t M, N, K;
// Strides of the A, B and C matrices (presumably leading dimensions -- confirm
// against the grid-descriptor construction).
ck::index_t StrideA, StrideB, StrideC;
// Untyped buffers for A, B and C; the grouped-GEMM argument casts them to the
// A/B/C element types before handing them to the kernel.
void *p_a, *p_b, *p_c;
};
} // namespace ck } // namespace ck
#endif #endif
...@@ -8,6 +8,12 @@ namespace ck { ...@@ -8,6 +8,12 @@ namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace device { namespace device {
// Describes one GEMM problem of a grouped GEMM: its extents and the operand
// strides. Operand buffers are not part of the shape; they are supplied to the
// device op as separate pointer vectors.
struct GemmShape
{
// GEMM extents.
ck::index_t M, N, K;
// Strides of the A, B and C matrices (presumably leading dimensions -- confirm
// against the grid-descriptor construction).
ck::index_t StrideA, StrideB, StrideC;
};
template <typename AElementwiseOperation, template <typename AElementwiseOperation,
typename BElementwiseOperation, typename BElementwiseOperation,
typename CElementwiseOperation> typename CElementwiseOperation>
...@@ -70,7 +76,10 @@ template <typename AElementwiseOperation, ...@@ -70,7 +76,10 @@ template <typename AElementwiseOperation,
typename CElementwiseOperation> typename CElementwiseOperation>
struct DeviceGroupedGemm : public BaseOperator struct DeviceGroupedGemm : public BaseOperator
{ {
virtual std::unique_ptr<BaseArgument> MakeArgumentPointer(std::vector<GemmShape>& gemm_shapes, virtual std::unique_ptr<BaseArgument> MakeArgumentPointer(std::vector<const void*> p_a,
std::vector<const void*> p_b,
std::vector<void*> p_c,
std::vector<GemmShape>& gemm_shapes,
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op, CElementwiseOperation c_element_op,
......
...@@ -244,7 +244,10 @@ struct DeviceGroupedGemmXdl ...@@ -244,7 +244,10 @@ struct DeviceGroupedGemmXdl
// Argument // Argument
struct Argument : public BaseArgument struct Argument : public BaseArgument
{ {
Argument(std::vector<GemmShape>& gemm_shapes, Argument(std::vector<const void*> p_a,
std::vector<const void*> p_b,
std::vector<void*> p_c,
std::vector<GemmShape>& gemm_shapes,
index_t M01, index_t M01,
index_t N01, index_t N01,
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
...@@ -256,12 +259,20 @@ struct DeviceGroupedGemmXdl ...@@ -256,12 +259,20 @@ struct DeviceGroupedGemmXdl
b_element_op_{b_element_op}, b_element_op_{b_element_op},
c_element_op_{c_element_op} c_element_op_{c_element_op}
{ {
grid_size = 0; grid_size_ = 0;
group_count_ = 0;
group_count_ = static_cast<int>(gemm_shapes.size());
if(!(group_count_ == p_a.size() && group_count_ == p_b.size() &&
group_count_ == p_c.size()))
{
throw std::runtime_error("wrong! group_count_ != P_a/b/c.size");
}
gemm_desc_kernel_arg_.reserve(group_count_);
for(index_t i = 0; i < gemm_shapes.size(); i++) for(index_t i = 0; i < gemm_shapes.size(); i++)
{ {
group_count_++;
const index_t M = gemm_shapes[i].M; const index_t M = gemm_shapes[i].M;
const index_t N = gemm_shapes[i].N; const index_t N = gemm_shapes[i].N;
const index_t K = gemm_shapes[i].K; const index_t K = gemm_shapes[i].K;
...@@ -279,10 +290,10 @@ struct DeviceGroupedGemmXdl ...@@ -279,10 +290,10 @@ struct DeviceGroupedGemmXdl
const index_t grid_size_grp = GridwiseGemm::CalculateGridSize(c_grid_desc_m_n_); const index_t grid_size_grp = GridwiseGemm::CalculateGridSize(c_grid_desc_m_n_);
const index_t BlockStart = grid_size; const index_t BlockStart = grid_size_;
const index_t BlockEnd = grid_size + grid_size_grp; const index_t BlockEnd = grid_size_ + grid_size_grp;
grid_size += grid_size_grp; grid_size_ += grid_size_grp;
if(GridwiseGemm::CheckValidity( if(GridwiseGemm::CheckValidity(
a_grid_desc_k0_m_k1_, b_grid_desc_k0_n_k1_, c_grid_desc_m_n_, M01_, N01_)) a_grid_desc_k0_m_k1_, b_grid_desc_k0_n_k1_, c_grid_desc_m_n_, M01_, N01_))
...@@ -299,9 +310,9 @@ struct DeviceGroupedGemmXdl ...@@ -299,9 +310,9 @@ struct DeviceGroupedGemmXdl
c_grid_desc_m_n_, c_grid_desc_m_n_,
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_, c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_,
block_2_ctile_map_, block_2_ctile_map_,
static_cast<const ADataType*>(gemm_shapes[i].p_a), static_cast<const ADataType*>(p_a[i]),
static_cast<const BDataType*>(gemm_shapes[i].p_b), static_cast<const BDataType*>(p_b[i]),
static_cast<CDataType*>(gemm_shapes[i].p_c), static_cast<CDataType*>(p_c[i]),
BlockStart, BlockStart,
BlockEnd}); BlockEnd});
} }
...@@ -318,7 +329,7 @@ struct DeviceGroupedGemmXdl ...@@ -318,7 +329,7 @@ struct DeviceGroupedGemmXdl
std::vector<GemmDescKernelArg> gemm_desc_kernel_arg_; std::vector<GemmDescKernelArg> gemm_desc_kernel_arg_;
index_t grid_size; index_t grid_size_;
}; };
// Invoker // Invoker
...@@ -395,7 +406,7 @@ struct DeviceGroupedGemmXdl ...@@ -395,7 +406,7 @@ struct DeviceGroupedGemmXdl
ave_time = launch_and_time_kernel(kernel, ave_time = launch_and_time_kernel(kernel,
nrepeat, nrepeat,
dim3(arg.grid_size), dim3(arg.grid_size_),
dim3(BlockSize), dim3(BlockSize),
0, 0,
gemm_desc_kernel_arg_arg, gemm_desc_kernel_arg_arg,
...@@ -419,7 +430,7 @@ struct DeviceGroupedGemmXdl ...@@ -419,7 +430,7 @@ struct DeviceGroupedGemmXdl
ave_time = launch_and_time_kernel(kernel, ave_time = launch_and_time_kernel(kernel,
nrepeat, nrepeat,
dim3(arg.grid_size), dim3(arg.grid_size_),
dim3(BlockSize), dim3(BlockSize),
0, 0,
gemm_desc_kernel_arg_arg, gemm_desc_kernel_arg_arg,
...@@ -459,25 +470,31 @@ struct DeviceGroupedGemmXdl ...@@ -459,25 +470,31 @@ struct DeviceGroupedGemmXdl
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg)); return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
} }
static auto MakeArgument(std::vector<GemmShape> gemm_shapes, static auto MakeArgument(std::vector<const void*> p_a,
std::vector<const void*> p_b,
std::vector<void*> p_c,
std::vector<GemmShape> gemm_shapes,
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op) CElementwiseOperation c_element_op)
{ {
return Argument{gemm_shapes, 1, 1, a_element_op, b_element_op, c_element_op}; return Argument{p_a, p_b, p_c, gemm_shapes, 1, 1, a_element_op, b_element_op, c_element_op};
} }
static auto MakeInvoker() { return Invoker{}; } static auto MakeInvoker() { return Invoker{}; }
// polymorphic // polymorphic
std::unique_ptr<BaseArgument> MakeArgumentPointer(std::vector<GemmShape>& gemm_shapes, std::unique_ptr<BaseArgument> MakeArgumentPointer(std::vector<const void*> p_a,
std::vector<const void*> p_b,
std::vector<void*> p_c,
std::vector<GemmShape>& gemm_shapes,
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op, CElementwiseOperation c_element_op,
index_t /* KBatch */ = 1) override index_t /* KBatch */ = 1) override
{ {
return std::make_unique<Argument>( return std::make_unique<Argument>(
gemm_shapes, 1, 1, a_element_op, b_element_op, c_element_op); p_a, p_b, p_c, gemm_shapes, 1, 1, a_element_op, b_element_op, c_element_op);
} }
// polymorphic // polymorphic
......
...@@ -85,7 +85,6 @@ __global__ void ...@@ -85,7 +85,6 @@ __global__ void
c_element_op, c_element_op,
gemm_desc_ptr[group_id].block_2_ctile_map_, gemm_desc_ptr[group_id].block_2_ctile_map_,
block_id_grp); block_id_grp);
#endif #endif
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment