"tests/pipelines/vscode:/vscode.git/clone" did not exist on "ca783a0f1f4ce8b0a16e6b96a8890edc47489e3a"
Commit 3c171550 authored by Aleksander Dudek

Batched gemm - messy validation check

parent 71eea17c
@@ -96,11 +96,13 @@ int run_batched_gemm_example(int argc, char* argv[])
         [](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
             if constexpr(std::is_same_v<decltype(layout), ck_tile::tensor_layout::gemm::RowMajor>)
             {
-                return ck_tile::HostTensorDescriptor({row, col}, {stride, 1_uz});
+                return ck_tile::HostTensorDescriptor({static_cast<size_t>(16), row, col},
+                                                     {row * col, stride, 1_uz});
             }
             else
             {
-                return ck_tile::HostTensorDescriptor({row, col}, {1_uz, stride});
+                return ck_tile::HostTensorDescriptor({static_cast<size_t>(16), row, col},
+                                                     {row * col, 1_uz, stride});
             }
         };
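The new HostTensorDescriptor prepends a hard-coded batch dimension of 16 and a batch stride of row * col, so element (b, r, c) lives one full matrix further along per batch; the column-major branch only swaps the inner strides to {row * col, 1_uz, stride}. A minimal sketch of that mapping for the row-major branch, assuming packed batches; the helper below is illustrative and not part of the commit.

    // Sketch only: how a packed, row-major {batch, row, col} descriptor with
    // strides {row * col, stride, 1} maps an index triple to a flat offset.
    // The batch count of 16 mirrors the hard-coded value in the diff above.
    #include <cstddef>

    constexpr std::size_t flat_offset(std::size_t b,
                                      std::size_t r,
                                      std::size_t c,
                                      std::size_t row,
                                      std::size_t col,
                                      std::size_t stride)
    {
        return b * (row * col) + r * stride + c; // batch, row, unit column stride
    }

    static_assert(flat_offset(/*b=*/1, /*r=*/0, /*c=*/0, /*row=*/4, /*col=*/8, /*stride=*/8) == 32,
                  "second batch starts one full row*col block later");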
@@ -194,8 +196,19 @@ int run_batched_gemm_example(int argc, char* argv[])
                                                CDataType,
                                                ALayout,
                                                BLayout,
-                                               CLayout>(
-        a_m_k_dev_buf, b_k_n_dev_buf, c_m_n_gpu_buf_ref, M, N, K, stride_A, stride_B, stride_C);
+                                               CLayout>(a_m_k_dev_buf,
+                                                        b_k_n_dev_buf,
+                                                        c_m_n_gpu_buf_ref,
+                                                        M,
+                                                        N,
+                                                        K,
+                                                        stride_A,
+                                                        stride_B,
+                                                        stride_C,
+                                                        batch_stride_A,
+                                                        batch_stride_B,
+                                                        batch_stride_C,
+                                                        batch_count);
 
     c_m_n_gpu_buf_ref.FromDevice(c_m_n_gpu_ref.data());
     pass = ck_tile::check_err(c_m_n_dev_result, c_m_n_gpu_ref);
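The reference call now forwards batch strides and a batch count, but their definitions are outside this hunk. A hedged sketch under the assumption of packed batches, with the batch count matching the hard-coded 16 above; names mirror the call, the example itself may compute them differently.

    // Assumption: packed batches, so each batch stride is just the element
    // count of one matrix. Plain std::size_t is used here for illustration.
    #include <cstddef>

    struct BatchedGemmShape
    {
        std::size_t M, N, K, batch_count;

        std::size_t batch_stride_A() const { return M * K; } // elements between A matrices
        std::size_t batch_stride_B() const { return K * N; } // elements between B matrices
        std::size_t batch_stride_C() const { return M * N; } // elements between C matrices
    };

    // e.g. BatchedGemmShape{256, 256, 64, 16}.batch_stride_A() == 256 * 64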
......
@@ -29,22 +29,22 @@ CK_TILE_HOST void reference_gemm(const HostTensor<ADataType>& a_m_k,
     const std::size_t N = b_k_n.get_length(1);
     const std::size_t K = a_m_k.get_length(1);
 
-    auto f_mn = [&](auto m, auto n) {
+    auto f_mn = [&](auto m, auto n, auto b) {
         AccDataType v_acc = 0;
 
         for(std::size_t k = 0; k < K; ++k)
         {
-            ADataType v_a = a_element_op(a_m_k(m, k));
-            BDataType v_b = b_element_op(b_k_n(k, n));
+            ADataType v_a = a_element_op(a_m_k(b, m, k));
+            BDataType v_b = b_element_op(b_k_n(b, k, n));
 
             v_acc +=
                 ck_tile::type_convert<AccDataType>(v_a) * ck_tile::type_convert<AccDataType>(v_b);
         }
 
-        c_m_n(m, n) = ck_tile::type_convert<CDataType>(acc_element_op(v_acc));
+        c_m_n(b, m, n) = ck_tile::type_convert<CDataType>(acc_element_op(v_acc));
     };
 
-    make_ParallelTensorFunctor(f_mn, M, N)(std::thread::hardware_concurrency());
+    make_ParallelTensorFunctor(f_mn, M, N, 16)(std::thread::hardware_concurrency());
 }
 
 template <typename ADataType,
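The host reference lambda now takes a leading batch index and the functor iterates a third dimension, with the batch count hard-coded to 16 at the call site rather than passed in. A standalone C++ sketch of the same batched loop, with packed row-major layouts assumed purely for illustration:

    // Standalone sketch of the batched reference loop above, assuming packed,
    // row-major A (B x M x K), B (B x K x N) and C (B x M x N) stored in
    // std::vector<float>. Purely illustrative; the real code uses HostTensor.
    #include <cstddef>
    #include <vector>

    void reference_batched_gemm(const std::vector<float>& a,
                                const std::vector<float>& b,
                                std::vector<float>& c,
                                std::size_t B, std::size_t M, std::size_t N, std::size_t K)
    {
        for(std::size_t bi = 0; bi < B; ++bi)
            for(std::size_t m = 0; m < M; ++m)
                for(std::size_t n = 0; n < N; ++n)
                {
                    float acc = 0.f;
                    for(std::size_t k = 0; k < K; ++k)
                        acc += a[bi * M * K + m * K + k] * b[bi * K * N + k * N + n];
                    c[bi * M * N + m * N + n] = acc;
                }
    }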
@@ -105,16 +105,20 @@ void reference_gemm_gpu(DeviceMem& a_device,
                         index_t K,
                         index_t stride_a,
                         index_t stride_b,
-                        index_t stride_c)
+                        index_t stride_c,
+                        index_t batch_stride_A,
+                        index_t batch_stride_B,
+                        index_t batch_stride_C,
+                        index_t batch_count)
 {
     ADataType* d_A;
     BDataType* d_B;
     CDataType* d_C;
 
-    hipError_t errA = hipMalloc(&d_A, M * K * sizeof(ADataType));
-    hipError_t errB = hipMalloc(&d_B, N * K * sizeof(BDataType));
-    hipError_t errC = hipMalloc(&d_C, M * N * sizeof(CDataType));
+    hipError_t errA = hipMalloc(&d_A, batch_count * M * K * sizeof(ADataType));
+    hipError_t errB = hipMalloc(&d_B, batch_count * N * K * sizeof(BDataType));
+    hipError_t errC = hipMalloc(&d_C, batch_count * M * N * sizeof(CDataType));
 
     if(errA != hipSuccess)
     {
         std::cerr << "Error allocating device memory for A: " << hipGetErrorString(errA)
@@ -136,15 +140,19 @@ void reference_gemm_gpu(DeviceMem& a_device,
         return; // Early exit on error
     }
 
-    errA = hipMemcpy(
-        d_A, a_device.GetDeviceBuffer(), M * K * sizeof(ADataType), hipMemcpyHostToDevice);
+    errA = hipMemcpy(d_A,
+                     a_device.GetDeviceBuffer(),
+                     batch_count * M * K * sizeof(ADataType),
+                     hipMemcpyHostToDevice);
     if(errA != hipSuccess)
    {
         std::cerr << "Error copying A to device: " << hipGetErrorString(errA) << std::endl;
     }
 
-    errB = hipMemcpy(
-        d_B, b_device.GetDeviceBuffer(), N * K * sizeof(BDataType), hipMemcpyHostToDevice);
+    errB = hipMemcpy(d_B,
+                     b_device.GetDeviceBuffer(),
+                     batch_count * N * K * sizeof(BDataType),
+                     hipMemcpyHostToDevice);
     if(errB != hipSuccess)
     {
         std::cerr << "Error copying B to device: " << hipGetErrorString(errB) << std::endl;
@@ -154,10 +162,20 @@ void reference_gemm_gpu(DeviceMem& a_device,
     int numThreadsPerBlock = 256; // Common choice for threads per block
     int numBlocks = (totalElements + numThreadsPerBlock - 1) / numThreadsPerBlock;
 
-    naive_gemm_kernel<ADataType, BDataType, AccDataType, CDataType, LayoutA, LayoutB, LayoutC>
-        <<<numBlocks, numThreadsPerBlock>>>(d_A, d_B, d_C, M, N, K, stride_a, stride_b, stride_c);
+    for(int i = 0; i < batch_count; ++i)
+    {
+        ADataType* d_ATemp = d_A + i * batch_stride_A;
+        BDataType* d_BTemp = d_B + i * batch_stride_B;
+        CDataType* d_CTemp = d_C + i * batch_stride_C;
+
+        naive_gemm_kernel<ADataType, BDataType, AccDataType, CDataType, LayoutA, LayoutB, LayoutC>
+            <<<numBlocks, numThreadsPerBlock>>>(
+                d_ATemp, d_BTemp, d_CTemp, M, N, K, stride_a, stride_b, stride_c);
+    }
 
-    errC = hipMemcpy(
-        c_device.GetDeviceBuffer(), d_C, M * N * sizeof(CDataType), hipMemcpyDeviceToHost);
+    errC = hipMemcpy(c_device.GetDeviceBuffer(),
+                     d_C,
+                     batch_count * M * N * sizeof(CDataType),
+                     hipMemcpyDeviceToHost);
     if(errC != hipSuccess)
     {
         std::cerr << "Error copying C to device: " << hipGetErrorString(errC) << std::endl;
......
@@ -89,13 +89,20 @@ struct BatchedGemmKernel
     CK_TILE_DEVICE void operator()(BatchedGemmCommonKargs kargs) const
     {
         const auto [i_m, i_n] = TilePartitioner{}();
-        // const auto i_k = blockIdx.z;
+        const auto i_k = blockIdx.z;
         // options
-        const ADataType* a_start = static_cast<const ADataType*>(
-            kargs.a_ptr); //+ __builtin_amdgcn_readfirstlane(i_k * kargs.batch_stride_A);
-        const BDataType* b_start = static_cast<const BDataType*>(
-            kargs.b_ptr); //+ __builtin_amdgcn_readfirstlane(i_k * kargs.batch_stride_B);
+        const ADataType* a_start = static_cast<const ADataType*>(kargs.a_ptr) +
+                                   __builtin_amdgcn_readfirstlane(i_k * kargs.batch_stride_A);
+        const BDataType* b_start = static_cast<const BDataType*>(kargs.b_ptr) +
+                                   __builtin_amdgcn_readfirstlane(i_k * kargs.batch_stride_B);
         // Convert pointers to tensor views
+        // if(threadIdx.x == 0 && blockIdx.x == 0 && blockIdx.y == 0)
+        // {
+        //     printf("__builtin_amdgcn_readfirstlane(i_k * kargs.batch_stride_A): %d\n",
+        //            __builtin_amdgcn_readfirstlane(i_k * kargs.batch_stride_A));
+        //     printf("__builtin_amdgcn_readfirstlane(i_k * kargs.batch_stride_B): %d\n",
+        //            __builtin_amdgcn_readfirstlane(i_k * kargs.batch_stride_B));
+        // }
         auto a_tensor_view = [&]() {
             if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
             {
@@ -172,8 +179,8 @@ struct BatchedGemmKernel
         auto c_block_tile =
             GemmPipeline{}.template operator()(a_block_window, b_block_window, num_loop, smem_ptr);
 
-        CDataType* c_start = static_cast<CDataType*>(
-            kargs.c_ptr); //; + __builtin_amdgcn_readfirstlane(i_k * kargs.batch_stride_C);
+        CDataType* c_start = static_cast<CDataType*>(kargs.c_ptr) +
+                             __builtin_amdgcn_readfirstlane(i_k * kargs.batch_stride_C);
         auto c_tensor_view = [&]() {
             if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
             {
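With i_k read from blockIdx.z, the host side has to expose the batch count in the third grid dimension; __builtin_amdgcn_readfirstlane is typically used here to keep the wave-uniform per-batch offsets in scalar registers. The ck_tile launcher and kargs setup are not part of this diff, so the plain-HIP sketch below is only a hypothetical illustration of the launch shape.

    // Hypothetical plain-HIP launch shape for a kernel that derives its batch
    // index from blockIdx.z, as the device code above does. Not the ck_tile
    // launcher; names are illustrative.
    #include <hip/hip_runtime.h>

    __global__ void batched_kernel_stub(int batch_count)
    {
        const int i_k = blockIdx.z; // one z-slice of the grid per batch
        (void)i_k;
        (void)batch_count;
    }

    void launch(int grid_m, int grid_n, int batch_count)
    {
        dim3 grid(grid_m, grid_n, batch_count); // batch count goes into gridDim.z
        dim3 block(256);
        batched_kernel_stub<<<grid, block>>>(batch_count);
    }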
......