Commit bb9c4a89 authored by Jing Zhang
Browse files

fixed comments: reserve, separate ptr and gemm_shapes

parent c3952566
......@@ -79,7 +79,11 @@ int main(int argc, char* argv[])
int group_count = 4;
// GEMM shape
std::vector<ck::GemmShape> gemm_shapes;
std::vector<ck::tensor_operation::device::GemmShape> gemm_shapes;
std::vector<const void*> p_a, p_b;
std::vector<void*> p_c;
gemm_shapes.reserve(group_count);
for(int i = 0; i < group_count; i++)
{
......@@ -87,7 +91,7 @@ int main(int argc, char* argv[])
int N = 128 + 128 * i;
int K = 64 + 64 * i;
gemm_shapes.push_back({M, N, K, K, K, N, nullptr, nullptr, nullptr});
gemm_shapes.push_back({M, N, K, K, K, N});
}
auto f_host_tensor_descriptor =
......@@ -105,14 +109,24 @@ int main(int argc, char* argv[])
};
std::vector<Tensor<ADataType>> a_tensors;
;
std::vector<Tensor<BDataType>> b_tensors;
std::vector<Tensor<CDataType>> c_host_tensors;
std::vector<Tensor<CDataType>> c_device_tensors;
a_tensors.reserve(group_count);
b_tensors.reserve(group_count);
c_host_tensors.reserve(group_count);
c_device_tensors.reserve(group_count);
using DeviceMemPtr = std::unique_ptr<DeviceMem>;
std::vector<DeviceMemPtr> a_tensors_device, b_tensors_device, c_tensors_device;
a_tensors_device.reserve(group_count);
b_tensors_device.reserve(group_count);
c_tensors_device.reserve(group_count);
std::size_t flop = 0, num_btype = 0;
for(int i = 0; i < gemm_shapes.size(); i++)
......@@ -164,9 +178,9 @@ int main(int argc, char* argv[])
a_tensors_device[i]->ToDevice(a_tensors[i].mData.data());
b_tensors_device[i]->ToDevice(b_tensors[i].mData.data());
gemm_shapes[i].p_a = a_tensors_device[i]->GetDeviceBuffer();
gemm_shapes[i].p_b = b_tensors_device[i]->GetDeviceBuffer();
gemm_shapes[i].p_c = c_tensors_device[i]->GetDeviceBuffer();
p_a.push_back(a_tensors_device[i]->GetDeviceBuffer());
p_b.push_back(b_tensors_device[i]->GetDeviceBuffer());
p_c.push_back(c_tensors_device[i]->GetDeviceBuffer());
}
auto a_element_op = AElementOp{};
......@@ -174,9 +188,10 @@ int main(int argc, char* argv[])
auto c_element_op = CElementOp{};
// do GEMM
auto gemm = DeviceGemmInstance{};
auto invoker = gemm.MakeInvoker();
auto argument = gemm.MakeArgument(gemm_shapes, a_element_op, b_element_op, c_element_op);
auto gemm = DeviceGemmInstance{};
auto invoker = gemm.MakeInvoker();
auto argument =
gemm.MakeArgument(p_a, p_b, p_c, gemm_shapes, a_element_op, b_element_op, c_element_op);
if(!gemm.IsSupportedArgument(argument))
{
......
......@@ -177,12 +177,5 @@ enum ActivTypeEnum_t
using index_t = int32_t;
using long_index_t = int64_t;
// Per-group GEMM descriptor that bundles the problem sizes, the tensor
// strides and the raw buffer pointers for the A, B and C operands.
struct GemmShape
{
// Problem sizes of one GEMM (presumably M x K times K x N -> M x N; confirm).
ck::index_t M, N, K;
// Strides of the A, B and C tensors, in that order.
ck::index_t StrideA, StrideB, StrideC;
// Untyped buffer pointers for A, B and C; the grouped-GEMM Argument
// static_casts them to the ADataType/BDataType/CDataType pointer types.
void *p_a, *p_b, *p_c;
};
} // namespace ck
#endif
......@@ -8,6 +8,12 @@ namespace ck {
namespace tensor_operation {
namespace device {
// Shape-only descriptor of one GEMM inside a grouped GEMM. It carries no
// buffer pointers: callers pass the A/B/C pointers as separate vectors next
// to the vector of GemmShape.
struct GemmShape
{
// Problem sizes of one GEMM (presumably M x K times K x N -> M x N; confirm).
ck::index_t M, N, K;
// Strides of the A, B and C tensors, in that order; aggregate
// initialization follows the order {M, N, K, StrideA, StrideB, StrideC}.
ck::index_t StrideA, StrideB, StrideC;
};
template <typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation>
......@@ -70,7 +76,10 @@ template <typename AElementwiseOperation,
typename CElementwiseOperation>
struct DeviceGroupedGemm : public BaseOperator
{
virtual std::unique_ptr<BaseArgument> MakeArgumentPointer(std::vector<GemmShape>& gemm_shapes,
virtual std::unique_ptr<BaseArgument> MakeArgumentPointer(std::vector<const void*> p_a,
std::vector<const void*> p_b,
std::vector<void*> p_c,
std::vector<GemmShape>& gemm_shapes,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op,
......
......@@ -244,7 +244,10 @@ struct DeviceGroupedGemmXdl
// Argument
struct Argument : public BaseArgument
{
Argument(std::vector<GemmShape>& gemm_shapes,
Argument(std::vector<const void*> p_a,
std::vector<const void*> p_b,
std::vector<void*> p_c,
std::vector<GemmShape>& gemm_shapes,
index_t M01,
index_t N01,
AElementwiseOperation a_element_op,
......@@ -256,12 +259,20 @@ struct DeviceGroupedGemmXdl
b_element_op_{b_element_op},
c_element_op_{c_element_op}
{
grid_size = 0;
group_count_ = 0;
grid_size_ = 0;
group_count_ = static_cast<int>(gemm_shapes.size());
if(!(group_count_ == p_a.size() && group_count_ == p_b.size() &&
group_count_ == p_c.size()))
{
throw std::runtime_error("wrong! group_count_ != P_a/b/c.size");
}
gemm_desc_kernel_arg_.reserve(group_count_);
for(index_t i = 0; i < gemm_shapes.size(); i++)
{
group_count_++;
const index_t M = gemm_shapes[i].M;
const index_t N = gemm_shapes[i].N;
const index_t K = gemm_shapes[i].K;
......@@ -279,10 +290,10 @@ struct DeviceGroupedGemmXdl
const index_t grid_size_grp = GridwiseGemm::CalculateGridSize(c_grid_desc_m_n_);
const index_t BlockStart = grid_size;
const index_t BlockEnd = grid_size + grid_size_grp;
const index_t BlockStart = grid_size_;
const index_t BlockEnd = grid_size_ + grid_size_grp;
grid_size += grid_size_grp;
grid_size_ += grid_size_grp;
if(GridwiseGemm::CheckValidity(
a_grid_desc_k0_m_k1_, b_grid_desc_k0_n_k1_, c_grid_desc_m_n_, M01_, N01_))
......@@ -299,9 +310,9 @@ struct DeviceGroupedGemmXdl
c_grid_desc_m_n_,
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_,
block_2_ctile_map_,
static_cast<const ADataType*>(gemm_shapes[i].p_a),
static_cast<const BDataType*>(gemm_shapes[i].p_b),
static_cast<CDataType*>(gemm_shapes[i].p_c),
static_cast<const ADataType*>(p_a[i]),
static_cast<const BDataType*>(p_b[i]),
static_cast<CDataType*>(p_c[i]),
BlockStart,
BlockEnd});
}
......@@ -318,7 +329,7 @@ struct DeviceGroupedGemmXdl
std::vector<GemmDescKernelArg> gemm_desc_kernel_arg_;
index_t grid_size;
index_t grid_size_;
};
// Invoker
......@@ -395,7 +406,7 @@ struct DeviceGroupedGemmXdl
ave_time = launch_and_time_kernel(kernel,
nrepeat,
dim3(arg.grid_size),
dim3(arg.grid_size_),
dim3(BlockSize),
0,
gemm_desc_kernel_arg_arg,
......@@ -419,7 +430,7 @@ struct DeviceGroupedGemmXdl
ave_time = launch_and_time_kernel(kernel,
nrepeat,
dim3(arg.grid_size),
dim3(arg.grid_size_),
dim3(BlockSize),
0,
gemm_desc_kernel_arg_arg,
......@@ -459,25 +470,31 @@ struct DeviceGroupedGemmXdl
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
}
static auto MakeArgument(std::vector<GemmShape> gemm_shapes,
static auto MakeArgument(std::vector<const void*> p_a,
std::vector<const void*> p_b,
std::vector<void*> p_c,
std::vector<GemmShape> gemm_shapes,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op)
{
return Argument{gemm_shapes, 1, 1, a_element_op, b_element_op, c_element_op};
return Argument{p_a, p_b, p_c, gemm_shapes, 1, 1, a_element_op, b_element_op, c_element_op};
}
// Returns a value-initialized Invoker for this device operation.
static auto MakeInvoker() { return Invoker{}; }
// polymorphic
std::unique_ptr<BaseArgument> MakeArgumentPointer(std::vector<GemmShape>& gemm_shapes,
std::unique_ptr<BaseArgument> MakeArgumentPointer(std::vector<const void*> p_a,
std::vector<const void*> p_b,
std::vector<void*> p_c,
std::vector<GemmShape>& gemm_shapes,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op,
index_t /* KBatch */ = 1) override
{
return std::make_unique<Argument>(
gemm_shapes, 1, 1, a_element_op, b_element_op, c_element_op);
p_a, p_b, p_c, gemm_shapes, 1, 1, a_element_op, b_element_op, c_element_op);
}
// polymorphic
......
......@@ -85,7 +85,6 @@ __global__ void
c_element_op,
gemm_desc_ptr[group_id].block_2_ctile_map_,
block_id_grp);
#endif
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment