"include/vscode:/vscode.git/clone" did not exist on "f03a1738d93c8ffccc570e8121e0a261e9950fa6"
Commit e3a4b967 authored by Jing Zhang's avatar Jing Zhang
Browse files

fixed mem issue with unique_ptr

parent 8fb2b172
......@@ -81,19 +81,17 @@ int main(int argc, char* argv[])
// GEMM shape
std::vector<ck::GemmShape> gemm_shapes;
int A_size = 0, B_size = 0, C_size = 0;
for(int i = 0; i < group_count; i++)
{
int M = 256 + 256 * i;
int N = 128 + 128 * i;
int K = 64 + 64 * i;
// int M = 256 + 256 * i;
// int N = 128 + 128 * i;
// int K = 64 + 64 * i;
gemm_shapes.push_back({M, N, K, K, K, N, nullptr, nullptr, nullptr});
int M = 3840;
int N = 1024;
int K = 4096;
A_size += gemm_shapes[i].M * gemm_shapes[i].K;
B_size += gemm_shapes[i].N * gemm_shapes[i].K;
C_size += gemm_shapes[i].M * gemm_shapes[i].N;
gemm_shapes.push_back({M, N, K, K, N, N, nullptr, nullptr, nullptr});
}
auto f_host_tensor_descriptor =
......@@ -115,6 +113,10 @@ int main(int argc, char* argv[])
std::vector<Tensor<CDataType>> c_host_tensors;
std::vector<Tensor<CDataType>> c_device_tensors;
using DeviceMemPtr = std::unique_ptr<DeviceMem>;
std::vector<DeviceMemPtr> a_tensors_device, b_tensors_device, c_tensors_device;
std::size_t flop = 0, num_btype = 0;
for(int i = 0; i < gemm_shapes.size(); i++)
......@@ -133,13 +135,10 @@ int main(int argc, char* argv[])
<< std::endl;
flop += std::size_t(2) * gemm_shapes[i].M * gemm_shapes[i].K * gemm_shapes[i].N;
num_btype += sizeof(ADataType) * gemm_shapes[i].M * gemm_shapes[i].K +
sizeof(BDataType) * gemm_shapes[i].K * gemm_shapes[i].N +
sizeof(CDataType) * gemm_shapes[i].M * gemm_shapes[i].N;
}
num_btype += sizeof(ADataType) * a_tensors[i].mDesc.GetElementSize() +
sizeof(BDataType) * b_tensors[i].mDesc.GetElementSize() +
sizeof(CDataType) * c_device_tensors[i].mDesc.GetElementSize();
for(int i = 0; i < gemm_shapes.size(); i++)
{
switch(init_method)
{
case 0: break;
......@@ -157,38 +156,23 @@ int main(int argc, char* argv[])
}
}
DeviceMem a_tensors_device_buf(sizeof(ADataType) * A_size);
DeviceMem b_tensors_device_buf(sizeof(BDataType) * B_size);
DeviceMem c_tensors_device_buf(sizeof(CDataType) * C_size);
std::vector<ADataType> a_tensors_data, b_tensors_data, c_tensors_data;
A_size = 0;
B_size = 0;
C_size = 0;
for(int i = 0; i < gemm_shapes.size(); i++)
{
a_tensors_data.insert(
a_tensors_data.end(), a_tensors[i].mData.begin(), a_tensors[i].mData.end());
b_tensors_data.insert(
b_tensors_data.end(), b_tensors[i].mData.begin(), b_tensors[i].mData.end());
gemm_shapes[i].p_a =
static_cast<ADataType*>(a_tensors_device_buf.GetDeviceBuffer()) + A_size;
gemm_shapes[i].p_b =
static_cast<BDataType*>(b_tensors_device_buf.GetDeviceBuffer()) + B_size;
gemm_shapes[i].p_c =
static_cast<CDataType*>(c_tensors_device_buf.GetDeviceBuffer()) + C_size;
A_size += gemm_shapes[i].M * gemm_shapes[i].K;
B_size += gemm_shapes[i].N * gemm_shapes[i].K;
C_size += gemm_shapes[i].M * gemm_shapes[i].N;
a_tensors_device.push_back(
std::make_unique<DeviceMem>(sizeof(ADataType) * a_tensors[i].mDesc.GetElementSize()));
b_tensors_device.push_back(
std::make_unique<DeviceMem>(sizeof(BDataType) * a_tensors[i].mDesc.GetElementSize()));
c_tensors_device.push_back(std::make_unique<DeviceMem>(
sizeof(CDataType) * c_device_tensors[i].mDesc.GetElementSize()));
a_tensors_device[i]->ToDevice(a_tensors[i].mData.data());
b_tensors_device[i]->ToDevice(b_tensors[i].mData.data());
gemm_shapes[i].p_a = a_tensors_device[i]->GetDeviceBuffer();
gemm_shapes[i].p_b = b_tensors_device[i]->GetDeviceBuffer();
gemm_shapes[i].p_c = c_tensors_device[i]->GetDeviceBuffer();
}
a_tensors_device_buf.ToDevice(a_tensors_data.data());
b_tensors_device_buf.ToDevice(b_tensors_data.data());
auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{};
auto c_element_op = CElementOp{};
......@@ -214,24 +198,11 @@ int main(int argc, char* argv[])
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
<< gemm.GetTypeString() << std::endl;
c_tensors_data.resize(C_size);
c_tensors_device_buf.FromDevice(c_tensors_data.data());
C_size = 0;
for(int i = 0; i < gemm_shapes.size(); i++)
{
memcpy(c_device_tensors[i].mData.data(),
c_tensors_data.data() + C_size,
c_device_tensors[i].mData.size() * sizeof(CDataType));
C_size += gemm_shapes[i].M * gemm_shapes[i].N;
}
if(do_verification)
{
for(int i = 0; i < gemm_shapes.size(); i++)
{
c_tensors_device[i]->FromDevice(c_device_tensors[i].mData.data());
auto ref_gemm = ReferenceGemmInstance{};
auto ref_invoker = ref_gemm.MakeInvoker();
......
......@@ -70,7 +70,7 @@ template <typename AElementwiseOperation,
typename CElementwiseOperation>
struct DeviceGroupedGemm : public BaseOperator
{
virtual std::unique_ptr<BaseArgument> MakeArgumentPointer(std::vector<GemmShape> gemm_shapes,
virtual std::unique_ptr<BaseArgument> MakeArgumentPointer(std::vector<GemmShape>& gemm_shapes,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op,
......
......@@ -242,7 +242,7 @@ struct DeviceGroupedGemmXdl
// Argument
struct Argument : public BaseArgument
{
Argument(std::vector<GemmShape> gemm_shapes,
Argument(std::vector<GemmShape>& gemm_shapes,
index_t M01,
index_t N01,
AElementwiseOperation a_element_op,
......@@ -360,8 +360,7 @@ struct DeviceGroupedGemmXdl
if(GridwiseGemm::CalculateHasMainK0BlockLoop(K0) != has_main_k0_block_loop)
{
throw std::runtime_error(
"wrong! GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 has invalid setting");
throw std::runtime_error("wrong! not all gemm has_main_k0_block_loop");
}
}
});
......@@ -435,11 +434,17 @@ struct DeviceGroupedGemmXdl
// Checks whether every GEMM group held by `arg` is accepted by the gridwise kernel.
//
// Each entry of arg.GemmShape_ supplies its own A/B/C grid descriptors; the
// M01_/N01_ values are shared by all groups. Returns true only if
// GridwiseGemm::CheckValidity accepts all groups (vacuously true when there
// are no groups). Previously an early `return` validated group 0 only and
// left the per-group loop unreachable.
static bool IsSupportedArgument(const Argument& arg)
{
    for(std::size_t i = 0; i < arg.GemmShape_.size(); ++i)
    {
        const auto& group = arg.GemmShape_[i];

        // One unsupported group makes the whole grouped launch unsupported.
        if(!GridwiseGemm::CheckValidity(group.a_grid_desc_k0_m_k1_,
                                        group.b_grid_desc_k0_n_k1_,
                                        group.c_grid_desc_m_n_,
                                        arg.M01_,
                                        arg.N01_))
        {
            return false;
        }
    }

    return true;
}
// polymorphic
......@@ -459,7 +464,7 @@ struct DeviceGroupedGemmXdl
static auto MakeInvoker() { return Invoker{}; }
// polymorphic
std::unique_ptr<BaseArgument> MakeArgumentPointer(std::vector<GemmShape> gemm_shapes,
std::unique_ptr<BaseArgument> MakeArgumentPointer(std::vector<GemmShape>& gemm_shapes,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op,
......
......@@ -60,24 +60,22 @@ __global__ void
}
});
#else
const GemmDesc* gemm_desc_ptr = reinterpret_cast<const GemmDesc*>(&gemm_desc_);
const auto gemm_desc_ptr = reinterpret_cast<const GemmDesc*>(&gemm_desc_);
index_t group_id = 0;
static_for<0, MaxGroupCount, 1>{}([&](auto i) {
group_id = (block_id >= gemm_desc_[i].BlockStart && block_id < gemm_desc_[i].BlockEnd)
group_id = (block_id >= gemm_desc_[i].BlockStart && block_id < gemm_desc_[i].BlockEnd &&
i < group_count)
? i
: group_id;
});
const index_t block_id_grp = block_id - gemm_desc_ptr[group_id].BlockStart;
const index_t a_offset_grp = gemm_desc_ptr[group_id].OffsetA;
const index_t b_offset_grp = gemm_desc_ptr[group_id].OffsetB;
const index_t c_offset_grp = gemm_desc_ptr[group_id].OffsetC;
GridwiseGemm::template Run<HasMainK0BlockLoop>(
p_a_grid + a_offset_grp,
p_b_grid + b_offset_grp,
p_c_grid + c_offset_grp,
gemm_desc_ptr[group_id].a_ptr,
gemm_desc_ptr[group_id].b_ptr,
gemm_desc_ptr[group_id].c_ptr,
p_shared,
gemm_desc_ptr[group_id].a_grid_desc_k0_m_k1_,
gemm_desc_ptr[group_id].b_grid_desc_k0_n_k1_,
......
......@@ -12,6 +12,7 @@ struct DeviceMem
{
DeviceMem() = delete;
DeviceMem(std::size_t mem_size);
DeviceMem(const DeviceMem& p);
void* GetDeviceBuffer();
void ToDevice(const void* p);
void FromDevice(void* p);
......
......@@ -5,6 +5,12 @@ DeviceMem::DeviceMem(std::size_t mem_size) : mMemSize(mem_size)
hipGetErrorString(hipMalloc(static_cast<void**>(&mpDeviceBuf), mMemSize));
}
// Copy constructor: copies the device pointer and the byte size from `p`.
// No new device allocation is made and no data is copied -- the hipMalloc /
// hipMemcpy deep-copy below is commented out -- so the new object refers to
// the same device buffer as `p`.
// NOTE(review): the destructor is not visible here. If it frees mpDeviceBuf,
// every copy will free the same buffer again -- confirm ownership, or
// delete/replace this constructor (e.g. hold DeviceMem via std::unique_ptr).
DeviceMem::DeviceMem(const DeviceMem& p) : mpDeviceBuf(p.mpDeviceBuf), mMemSize(p.mMemSize)
{
// Disabled deep copy (allocate a fresh buffer, then copy device-to-device):
// hipGetErrorString(hipMalloc(static_cast<void**>(&mpDeviceBuf), mMemSize));
// hipGetErrorString(hipMemcpy(mpDeviceBuf, p.mpDeviceBuf, mMemSize, hipMemcpyDeviceToDevice));
}
void* DeviceMem::GetDeviceBuffer() { return mpDeviceBuf; }
void DeviceMem::ToDevice(const void* p)
......
......@@ -23,9 +23,8 @@ using PassThrough = ck::tensor_operation::element_wise::PassThrough;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization_t::Default;
// Compilation parameters for a[m, k] * b[k, n] = c[m, n]
using device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instances =
std::tuple<
// clang-format off
using device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instances = std::tuple<
// clang-format off
//#################| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer|
//#################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
//#################| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
......@@ -48,13 +47,14 @@ using device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instances =
//DeviceGroupedGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 16, 32, 4, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1>,
//DeviceGroupedGemmXdl< F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 16, 16, 4, 8, 16, 16, 1, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 7, 1>
DeviceGroupedGemmXdl< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>
// clang-format on
>;
// clang-format on
>;
// Hands the f16/f16/f16 grouped-GEMM instance tuple declared above to
// add_device_operation_instances together with the caller's `instances`
// vector (presumably appending one operation per tuple entry -- the helper
// is defined elsewhere).
//
// The call is made exactly once; a duplicated call would register every
// instance twice and make the profiler run each configuration twice.
void add_device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instances(
    std::vector<DeviceGroupedGemmPtr<PassThrough, PassThrough, PassThrough>>& instances)
{
    add_device_operation_instances(instances,
                                   device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instances{});
}
} // namespace device_grouped_gemm_instance
......
......@@ -16,15 +16,19 @@ namespace tensor_operation {
namespace device {
namespace device_grouped_gemm_instance {
using DeviceGroupedGemmNoOpPtr =
ck::tensor_operation::device::DeviceGroupedGemmPtr<ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough>;
void add_device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instances(std::vector<DeviceGroupedGemmNoOpPtr>&);
//void add_device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instances(std::vector<DeviceGroupedGemmNoOpPtr>&);
//void add_device_grouped_gemm_xdl_f16_f16_f16_km_kn_mn_instances(std::vector<DeviceGroupedGemmNoOpPtr>&);
//void add_device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_instances(std::vector<DeviceGroupedGemmNoOpPtr>&);
using DeviceGroupedGemmNoOpPtr = ck::tensor_operation::device::DeviceGroupedGemmPtr<
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough>;
void add_device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instances(
std::vector<DeviceGroupedGemmNoOpPtr>&);
// void
// add_device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instances(std::vector<DeviceGroupedGemmNoOpPtr>&);
// void
// add_device_grouped_gemm_xdl_f16_f16_f16_km_kn_mn_instances(std::vector<DeviceGroupedGemmNoOpPtr>&);
// void
// add_device_grouped_gemm_xdl_f16_f16_f16_km_nk_mn_instances(std::vector<DeviceGroupedGemmNoOpPtr>&);
} // namespace device_grouped_gemm_instance
} // namespace device
......@@ -41,15 +45,15 @@ template <typename ADataType,
typename BLayout,
typename CLayout>
void profile_grouped_gemm_impl(int do_verification,
int init_method,
bool do_log,
int nrepeat,
std::vector<int> Ms,
std::vector<int> Ns,
std::vector<int> Ks,
std::vector<int> StrideAs,
std::vector<int> StrideBs,
std::vector<int> StrideCs)
int init_method,
bool do_log,
int nrepeat,
std::vector<int> Ms,
std::vector<int> Ns,
std::vector<int> Ks,
std::vector<int> StrideAs,
std::vector<int> StrideBs,
std::vector<int> StrideCs)
{
auto f_host_tensor_descriptor =
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
......@@ -65,41 +69,48 @@ void profile_grouped_gemm_impl(int do_verification,
}
};
std::vector<Tensor<ADataType>> a_m_k;
std::vector<Tensor<BDataType>> b_k_n;
std::vector<Tensor<CDataType>> c_m_n;
std::vector<Tensor<CDataType>> c_m_n_device_results;
// int A_size = 0, B_size = 0, C_size = 0;
for(int i = 0; i < Ms.size(); i++)
{
a_m_k.push_back(Tensor<ADataType>(f_host_tensor_descriptor(
Ms[i], Ks[i], StrideAs[i], ALayout{})));
b_k_n.push_back(Tensor<BDataType>(f_host_tensor_descriptor(
Ks[i], Ns[i], StrideBs[i], BLayout{})));
c_m_n.push_back(Tensor<CDataType>(f_host_tensor_descriptor(
Ms[i], Ns[i], StrideCs[i], CLayout{})));
a_m_k.push_back(
Tensor<ADataType>(f_host_tensor_descriptor(Ms[i], Ks[i], StrideAs[i], ALayout{})));
b_k_n.push_back(
Tensor<BDataType>(f_host_tensor_descriptor(Ks[i], Ns[i], StrideBs[i], BLayout{})));
c_m_n_device_results.push_back(
Tensor<CDataType>(f_host_tensor_descriptor(Ms[i], Ns[i], StrideCs[i], CLayout{})));
std::cout << "a_m_k[" << i << "]:" << a_m_k[i].mDesc << std::endl;
std::cout << "b_k_n[" << i << "]:" << b_k_n[i].mDesc << std::endl;
std::cout << "c_m_n[" << i << "]:" << c_m_n[i].mDesc << std::endl;
std::cout << "c_m_n_device_results[" << i << "]:" << c_m_n_device_results[i].mDesc
<< std::endl;
std::size_t num_thread = std::thread::hardware_concurrency();
switch(init_method)
{
case 0: break;
case 1:
a_m_k[i].GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5}, num_thread);
b_k_n[i].GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5}, num_thread);
break;
default:
a_m_k[i].GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0}, num_thread);
b_k_n[i].GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5}, num_thread);
case 0: break;
case 1:
a_m_k[i].GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5}, num_thread);
b_k_n[i].GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5}, num_thread);
break;
default:
a_m_k[i].GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0}, num_thread);
b_k_n[i].GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5}, num_thread);
}
// set zero to c_device_buf
c_m_n[i].GenerateTensorValue(GeneratorTensor_0<CDataType>{}, num_thread);
}
c_m_n_device_results[i].GenerateTensorValue(GeneratorTensor_0<CDataType>{}, num_thread);
// A_size += a_m_k[i].mDesc.GetElementSpace();
// B_size += b_k_n[i].mDesc.GetElementSpace();
// C_size += c_m_n_device_results[i].mDesc.GetElementSpace();
}
using AElementOp = ck::tensor_operation::element_wise::PassThrough;
using BElementOp = ck::tensor_operation::element_wise::PassThrough;
......@@ -114,28 +125,112 @@ void profile_grouped_gemm_impl(int do_verification,
// }
std::vector<DeviceMem> a_device_buf, b_device_buf, c_device_buf;
//DeviceMem a_device_buf(sizeof(ADataType) * a_m_k[i].mDesc.GetElementSpace());
//DeviceMem b_device_buf(sizeof(BDataType) * b_k_n[i].mDesc.GetElementSpace());
//DeviceMem c_device_buf(sizeof(CDataType) * c_m_n[i].mDesc.GetElementSpace());
// std::vector<DeviceMem> a_device_buf, b_device_buf, c_device_buf;
std::vector<void*> a_device_buf, b_device_buf, c_device_buf;
// DeviceMem a_device_buf_(sizeof(ADataType) * A_size);
// DeviceMem b_device_buf_(sizeof(BDataType) * B_size);
// DeviceMem c_device_buf_(sizeof(CDataType) * C_size);
// std::vector<ADataType> a_tensors_data;
// std::vector<BDataType> b_tensors_data;
// std::vector<CDataType> c_tensors_data;
std::vector<GemmShape> gemm_shapes;
// A_size = 0;
// B_size = 0;
// C_size = 0;
for(int i = 0; i < Ms.size(); i++)
{
a_device_buf.push_back(DeviceMem(sizeof(ADataType) * a_m_k[i].mDesc.GetElementSpace()));
b_device_buf.push_back(DeviceMem(sizeof(BDataType) * b_k_n[i].mDesc.GetElementSpace()));
c_device_buf.push_back(DeviceMem(sizeof(CDataType) * c_m_n[i].mDesc.GetElementSpace()));
a_device_buf[i].ToDevice(a_m_k[i].mData.data());
b_device_buf[i].ToDevice(b_k_n[i].mData.data());
c_device_buf[i].ToDevice(c_m_n[i].mData.data());
// a_tensors_data.insert(a_tensors_data.end(), a_m_k[i].mData.begin(),
// a_m_k[i].mData.end()); b_tensors_data.insert(b_tensors_data.end(),
// b_k_n[i].mData.begin(), b_k_n[i].mData.end());
// c_tensors_data.insert(c_tensors_data.end(), c_m_n_device_results[i].mData.begin(),
// c_m_n_device_results[i].mData.end());
void *a_device_buf_, *b_device_buf_, *c_device_buf_;
hipGetErrorString(hipMalloc(static_cast<void**>(&a_device_buf_),
sizeof(ADataType) * a_m_k[i].mDesc.GetElementSpace()));
hipGetErrorString(hipMalloc(static_cast<void**>(&b_device_buf_),
sizeof(BDataType) * b_k_n[i].mDesc.GetElementSpace()));
hipGetErrorString(
hipMalloc(static_cast<void**>(&c_device_buf_),
sizeof(CDataType) * c_m_n_device_results[i].mDesc.GetElementSpace()));
// DeviceMem a_device_buf_(sizeof(ADataType) * a_m_k[i].mDesc.GetElementSpace());
// DeviceMem b_device_buf_(sizeof(BDataType) * b_k_n[i].mDesc.GetElementSpace());
// DeviceMem c_device_buf_(sizeof(CDataType) *
// c_m_n_device_results[i].mDesc.GetElementSpace());
hipGetErrorString(hipMemcpy(a_device_buf_,
a_m_k[i].mData.data(),
sizeof(ADataType) * a_m_k[i].mDesc.GetElementSpace(),
hipMemcpyHostToDevice));
hipGetErrorString(hipMemcpy(b_device_buf_,
b_k_n[i].mData.data(),
sizeof(BDataType) * b_k_n[i].mDesc.GetElementSpace(),
hipMemcpyHostToDevice));
hipGetErrorString(
hipMemcpy(c_device_buf_,
c_m_n_device_results[i].mData.data(),
sizeof(CDataType) * c_m_n_device_results[i].mDesc.GetElementSpace(),
hipMemcpyHostToDevice));
// a_device_buf_.ToDevice(a_m_k[i].mData.data());
// b_device_buf_.ToDevice(b_k_n[i].mData.data());
// c_device_buf_.ToDevice(c_m_n_device_results[i].mData.data());
a_device_buf.push_back(a_device_buf_);
b_device_buf.push_back(b_device_buf_);
c_device_buf.push_back(c_device_buf_);
// a_device_buf.push_back(a_device_buf_);
// b_device_buf.push_back(b_device_buf_);
// c_device_buf.push_back(c_device_buf_);
// gemm_shapes.push_back({Ms[i],
// Ns[i],
// Ks[i],
// StrideAs[i],
// StrideBs[i],
// StrideCs[i],
// a_device_buf[i].GetDeviceBuffer(),
// b_device_buf[i].GetDeviceBuffer(),
// c_device_buf[i].GetDeviceBuffer()});
// printf("%p %p %p\n",
// a_device_buf[i].GetDeviceBuffer(),
// b_device_buf[i].GetDeviceBuffer(),
// c_device_buf[i].GetDeviceBuffer());
gemm_shapes.push_back({Ms[i],
Ns[i],
Ks[i],
StrideAs[i],
StrideBs[i],
StrideCs[i],
a_device_buf_,
b_device_buf_,
c_device_buf_});
// A_size += a_m_k[i].mDesc.GetElementSpace();
// B_size += b_k_n[i].mDesc.GetElementSpace();
// C_size += c_m_n_device_results[i].mDesc.GetElementSpace();
}
// a_device_buf_.ToDevice(a_tensors_data.data());
// b_device_buf_.ToDevice(b_tensors_data.data());
// c_device_buf_.ToDevice(c_tensors_data.data());
// add device GEMM instances
std::vector<ck::tensor_operation::device::device_grouped_gemm_instance::DeviceGroupedGemmNoOpPtr> gemm_ptrs;
std::vector<
ck::tensor_operation::device::device_grouped_gemm_instance::DeviceGroupedGemmNoOpPtr>
gemm_ptrs;
if constexpr(is_same<ADataType, half_t>::value && is_same<BDataType, half_t>::value &&
is_same<CDataType, half_t>::value)
is_same<CDataType, half_t>::value)
{
if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
......@@ -143,7 +238,6 @@ void profile_grouped_gemm_impl(int do_verification,
{
ck::tensor_operation::device::device_grouped_gemm_instance::
add_device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instances(gemm_ptrs);
}
#if 0
else if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
......@@ -216,24 +310,15 @@ void profile_grouped_gemm_impl(int do_verification,
float best_tflops = 0;
float best_gb_per_sec = 0;
#if 0
#if 1
// profile device GEMM instances
for(auto& gemm_ptr : gemm_ptrs)
{
auto argument_ptr =
gemm_ptr->MakeArgumentPointer(static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()),
static_cast<BDataType*>(b_device_buf.GetDeviceBuffer()),
static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()),
M,
N,
K,
StrideA,
StrideB,
StrideC,
gemm_ptr->MakeArgumentPointer(gemm_shapes,
ck::tensor_operation::element_wise::PassThrough{},
ck::tensor_operation::element_wise::PassThrough{},
ck::tensor_operation::element_wise::PassThrough{},
KBatch);
ck::tensor_operation::element_wise::PassThrough{});
auto invoker_ptr = gemm_ptr->MakeInvokerPointer();
......@@ -243,6 +328,7 @@ void profile_grouped_gemm_impl(int do_verification,
float ave_time = invoker_ptr->Run(argument_ptr.get(), nrepeat);
#if 0
std::size_t flop = std::size_t(2) * M * N * K;
std::size_t num_btype =
......@@ -262,54 +348,36 @@ void profile_grouped_gemm_impl(int do_verification,
best_ave_time = ave_time;
best_gb_per_sec = gb_per_sec;
}
#endif
if(do_verification)
{
c_device_buf.FromDevice(c_m_n_device_result.mData.data());
if constexpr(is_same<ADataType, ck::bhalf_t>::value &&
is_same<BDataType, ck::bhalf_t>::value &&
is_same<CDataType, ck::bhalf_t>::value)
{
Tensor<float> a_f32_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
Tensor<float> b_f32_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
Tensor<float> c_m_n_host_result(
f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
Tensor<float> c_m_n_device_f32_result(
f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
// c_tensors_data.resize(C_size);
bf16_to_f32_(a_m_k, a_f32_m_k);
bf16_to_f32_(b_k_n, b_f32_k_n);
bf16_to_f32_(c_m_n_device_result, c_m_n_device_f32_result);
// c_device_buf_.FromDevice(c_tensors_data.data());
using ReferenceGemmInstance = ck::tensor_operation::host::
ReferenceGemm<float, float, float, AElementOp, BElementOp, CElementOp>;
auto ref_gemm = ReferenceGemmInstance{};
auto ref_invoker = ref_gemm.MakeInvoker();
// C_size = 0;
// for(int i = 0; i < gemm_shapes.size(); i++)
//{
// memcpy(c_m_n_device_results[i].mData.data(),
// c_tensors_data.data() + C_size,
// c_m_n_device_results[i].mDesc.GetElementSpace() * sizeof(CDataType));
auto ref_argument = ref_gemm.MakeArgument(a_f32_m_k,
b_f32_k_n,
c_m_n_host_result,
a_element_op,
b_element_op,
c_element_op);
// C_size += c_m_n_device_results[i].mDesc.GetElementSpace();
//}
ref_invoker.Run(ref_argument);
for(int i = 0; i < gemm_shapes.size(); i++)
{
hipGetErrorString(hipMemcpy(c_m_n_device_results[i].mData.data(),
c_device_buf[i],
sizeof(CDataType) *
c_m_n_device_results[i].mDesc.GetElementSpace(),
hipMemcpyDeviceToHost));
check_error(c_m_n_host_result, c_m_n_device_f32_result);
// hipGetErrorString(hipFree(c_device_buf[i]));
if(do_log)
{
LogRangeAsType<float>(
std::cout << "c_host : ", c_m_n_host_result.mData, ",")
<< std::endl;
}
}
else
{
Tensor<CDataType> c_m_n_host_result(
f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
f_host_tensor_descriptor(Ms[i], Ns[i], StrideCs[i], CLayout{}));
using ReferenceGemmInstance =
ck::tensor_operation::host::ReferenceGemm<ADataType,
......@@ -322,27 +390,30 @@ void profile_grouped_gemm_impl(int do_verification,
auto ref_gemm = ReferenceGemmInstance{};
auto ref_invoker = ref_gemm.MakeInvoker();
auto ref_argument = ref_gemm.MakeArgument(
a_m_k, b_k_n, c_m_n_host_result, a_element_op, b_element_op, c_element_op);
auto ref_argument = ref_gemm.MakeArgument(a_m_k[i],
b_k_n[i],
c_m_n_host_result,
a_element_op,
b_element_op,
c_element_op);
ref_invoker.Run(ref_argument);
check_error(c_m_n_host_result, c_m_n_device_result);
check_error(c_m_n_host_result, c_m_n_device_results[i]);
if(do_log)
{
// LogRangeAsType<float>(std::cout << "a : ", a_m_k[i].mData, ",")
//<< std::endl;
// LogRangeAsType<float>(std::cout << "b: ", b_k_n[i].mData, ",") <<
// std::endl;
LogRangeAsType<float>(
std::cout << "c_host : ", c_m_n_host_result.mData, ",")
std::cout << "c_device: ", c_m_n_device_results[i].mData, ",")
<< std::endl;
// LogRangeAsType<float>(
// std::cout << "c_host : ", c_m_n_host_result.mData, ",")
//<< std::endl;
}
}
if(do_log)
{
LogRangeAsType<float>(std::cout << "a : ", a_m_k.mData, ",") << std::endl;
LogRangeAsType<float>(std::cout << "b: ", b_k_n.mData, ",") << std::endl;
LogRangeAsType<float>(std::cout << "c_device: ", c_m_n_device_result.mData, ",")
<< std::endl;
}
}
}
else
......
......@@ -26,7 +26,7 @@ enum GemmDataType
INT8_INT8_INT8, // 3
};
std::vector<int> stringToArray(char *input)
std::vector<int> stringToArray(char* input)
{
std::vector<int> out;
......@@ -34,7 +34,8 @@ std::vector<int> stringToArray(char *input)
std::string item;
while (std::getline(in, item, ',')) {
while(std::getline(in, item, ','))
{
out.push_back(std::stoi(item));
}
......@@ -69,30 +70,33 @@ int profile_grouped_gemm(int argc, char* argv[])
const auto Ms = stringToArray(argv[8]);
const auto Ns = stringToArray(argv[9]);
const auto Ks = stringToArray(argv[10]);
const auto StrideAs = stringToArray(argv[11]);
const auto StrideBs = stringToArray(argv[12]);
const auto StrideCs = stringToArray(argv[13]);
for(int i = 0; i < Ms.size(); i++)
{
std::cout << "M: " << Ms[i] << " N: " << Ns[i] << " K: " << Ks[i] << std::endl;
}
if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_KN_MN)
{
ck::profiler::profile_grouped_gemm_impl<ck::half_t,
ck::half_t,
ck::half_t,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor>(
do_verification,
init_method,
do_log,
nrepeat,
Ms,
Ns,
Ks,
StrideAs,
StrideBs,
StrideCs);
ck::half_t,
ck::half_t,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor,
ck::tensor_layout::gemm::RowMajor>(do_verification,
init_method,
do_log,
nrepeat,
Ms,
Ns,
Ks,
StrideAs,
StrideBs,
StrideCs);
}
#if 0
else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_NK_MN)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment