Commit e92395d9 authored by coderfeli

Merge remote-tracking branch 'origin/cka8w8_devtimer' into update_cka8w8_uc

parents 842d910e 7efafa11
...@@ -29,10 +29,6 @@ using BDataType = Types::BDataType; ...@@ -29,10 +29,6 @@ using BDataType = Types::BDataType;
using AccDataType = Types::AccDataType; using AccDataType = Types::AccDataType;
using CDataType = Types::CDataType; using CDataType = Types::CDataType;
struct batched_gemm_kargs : public ck_tile::BatchedGemmHostArgs
{
};
auto create_args(int argc, char* argv[]) auto create_args(int argc, char* argv[])
{ {
ck_tile::ArgParser arg_parser; ck_tile::ArgParser arg_parser;
...@@ -60,4 +56,4 @@ auto create_args(int argc, char* argv[]) ...@@ -60,4 +56,4 @@ auto create_args(int argc, char* argv[])
} }
// host API // host API
float batched_gemm(batched_gemm_kargs args, const ck_tile::stream_config& s); float batched_gemm(const ck_tile::BatchedGemmHostArgs& args, const ck_tile::stream_config& s);
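Note: with the empty batched_gemm_kargs wrapper removed, callers now build ck_tile::BatchedGemmHostArgs directly. A minimal sketch of the updated host-side call, using only the field names visible in this diff (the remaining size/stride/batch fields are assumed to be set the same way as before, and the stream_config object is prepared elsewhere):

    ck_tile::BatchedGemmHostArgs args;
    args.a_ptr = a_m_k_dev_buf.GetDeviceBuffer();
    args.b_ptr = b_k_n_dev_buf.GetDeviceBuffer();
    args.c_ptr = c_m_n_dev_buf.GetDeviceBuffer();
    // ... sizes, strides and batch count filled in as before
    float avg_time_ms = batched_gemm(args, stream_config); // stream_config configured elsewhere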
...@@ -20,7 +20,7 @@ float invoke_batched_gemm(ck_tile::DeviceMem& a_m_k_dev_buf, ...@@ -20,7 +20,7 @@ float invoke_batched_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
int n_warmup, int n_warmup,
int n_repeat) int n_repeat)
{ {
batched_gemm_kargs args; ck_tile::BatchedGemmHostArgs args;
args.a_ptr = a_m_k_dev_buf.GetDeviceBuffer(); args.a_ptr = a_m_k_dev_buf.GetDeviceBuffer();
args.b_ptr = b_k_n_dev_buf.GetDeviceBuffer(); args.b_ptr = b_k_n_dev_buf.GetDeviceBuffer();
args.c_ptr = c_m_n_dev_buf.GetDeviceBuffer(); args.c_ptr = c_m_n_dev_buf.GetDeviceBuffer();
...@@ -188,15 +188,33 @@ int run_batched_gemm_example_with_layouts(int argc, ...@@ -188,15 +188,33 @@ int run_batched_gemm_example_with_layouts(int argc,
c_m_n_gpu_ref.SetZero(); c_m_n_gpu_ref.SetZero();
c_m_n_gpu_buf_ref.SetZero(); c_m_n_gpu_buf_ref.SetZero();
ADataType* d_A;
BDataType* d_B;
CDataType* d_C;
ck_tile::hip_check_error(hipMalloc(&d_A, batch_count * M * K * sizeof(ADataType)));
ck_tile::hip_check_error(hipMalloc(&d_B, batch_count * N * K * sizeof(BDataType)));
ck_tile::hip_check_error(hipMalloc(&d_C, batch_count * M * N * sizeof(CDataType)));
ck_tile::hip_check_error(hipMemcpy(d_A,
a_m_k_dev_buf.GetDeviceBuffer(),
batch_count * M * K * sizeof(ADataType),
hipMemcpyHostToDevice));
ck_tile::hip_check_error(hipMemcpy(d_B,
b_k_n_dev_buf.GetDeviceBuffer(),
batch_count * N * K * sizeof(BDataType),
hipMemcpyHostToDevice));
ck_tile::reference_batched_gemm_gpu<ADataType, ck_tile::reference_batched_gemm_gpu<ADataType,
BDataType, BDataType,
AccDataType, AccDataType,
CDataType, CDataType,
ALayout, ALayout,
BLayout, BLayout,
CLayout>(a_m_k_dev_buf, CLayout>(d_A,
b_k_n_dev_buf, d_B,
c_m_n_gpu_buf_ref, d_C,
M, M,
N, N,
K, K,
...@@ -208,6 +226,15 @@ int run_batched_gemm_example_with_layouts(int argc, ...@@ -208,6 +226,15 @@ int run_batched_gemm_example_with_layouts(int argc,
batch_stride_C, batch_stride_C,
batch_count); batch_count);
ck_tile::hip_check_error(hipMemcpy(c_m_n_gpu_buf_ref.GetDeviceBuffer(),
d_C,
batch_count * M * N * sizeof(CDataType),
hipMemcpyDeviceToHost));
ck_tile::hip_check_error(hipFree(d_A));
ck_tile::hip_check_error(hipFree(d_B));
ck_tile::hip_check_error(hipFree(d_C));
c_m_n_gpu_buf_ref.FromDevice(c_m_n_gpu_ref.data()); c_m_n_gpu_buf_ref.FromDevice(c_m_n_gpu_ref.data());
pass = ck_tile::check_err(c_m_n_dev_result, c_m_n_gpu_ref); pass = ck_tile::check_err(c_m_n_dev_result, c_m_n_gpu_ref);
......
...@@ -34,13 +34,19 @@ using grouped_gemm_kargs = ck_tile::GroupedGemmHostArgs; ...@@ -34,13 +34,19 @@ using grouped_gemm_kargs = ck_tile::GroupedGemmHostArgs;
auto create_args(int argc, char* argv[]) auto create_args(int argc, char* argv[])
{ {
ck_tile::ArgParser arg_parser; ck_tile::ArgParser arg_parser;
arg_parser.insert("a_layout", "R", "A tensor data layout - Row by default") arg_parser.insert("Ms", "", "M dimensions - empty by default.")
.insert("b_layout", "R", "B tensor data layout - Row by default") .insert("Ns", "", "N dimensions - empty by default.")
.insert("c_layout", "R", "C tensor data layout - Row by default") .insert("Ks", "", "K dimensions - empty by default.")
.insert("validate", "1", "0. No validation, 1. Validation on CPU") .insert("stride_As", "", "Tensor A strides - it is empty by default.")
.insert("warmup", "10", "number of iterations before benchmark the kernel") .insert("stride_Bs", "", "Tensor B strides - it is empty by default.")
.insert("repeat", "100", "number of iterations to benchmark the kernel") .insert("stride_Cs", "", "Tensor C strides - it is empty by default.")
.insert("group_count", "16", "group count"); .insert("a_layout", "R", "A tensor data layout - Row by default.")
.insert("b_layout", "R", "B tensor data layout - Row by default.")
.insert("c_layout", "R", "C tensor data layout - Row by default.")
.insert("validate", "1", "0. No validation, 1. Validation on CPU.")
.insert("warmup", "10", "number of iterations before benchmark the kernel.")
.insert("repeat", "100", "number of iterations to benchmark the kernel.")
.insert("group_count", "16", "group count.");
bool result = arg_parser.parse(argc, argv); bool result = arg_parser.parse(argc, argv);
return std::make_tuple(result, arg_parser); return std::make_tuple(result, arg_parser);
......
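Note: the new Ms/Ns/Ks/stride_* options take comma-separated lists and are parsed with the get_int_vec helper added to ArgParser further down. A hypothetical invocation (binary name and values illustrative only, list lengths must match group_count):

    ./tile_example_grouped_gemm -group_count=3 -Ms=256,512,768 -Ns=128,128,128 -Ks=64,64,64 -stride_As=64,64,64 -stride_Bs=64,64,64 -stride_Cs=128,128,128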
...@@ -53,17 +53,24 @@ int run_grouped_gemm_example_with_layouts(int argc, ...@@ -53,17 +53,24 @@ int run_grouped_gemm_example_with_layouts(int argc,
return -1; return -1;
}; };
auto valid_input_data = [&](int group_count, const auto&... args) {
return !(args.empty() || ...) && group_count == (args.size() == ...);
};
const int group_count = arg_parser.get_int("group_count"); const int group_count = arg_parser.get_int("group_count");
const int repeat = arg_parser.get_int("repeat"); const int repeat = arg_parser.get_int("repeat");
const int warmup = arg_parser.get_int("warmup"); const int warmup = arg_parser.get_int("warmup");
std::vector<ck_tile::index_t> Ms; std::vector<ck_tile::index_t> Ms = arg_parser.get_int_vec("Ms");
std::vector<ck_tile::index_t> Ns; std::vector<ck_tile::index_t> Ns = arg_parser.get_int_vec("Ns");
std::vector<ck_tile::index_t> Ks; std::vector<ck_tile::index_t> Ks = arg_parser.get_int_vec("Ks");
std::vector<ck_tile::index_t> stride_As; std::vector<ck_tile::index_t> stride_As = arg_parser.get_int_vec("stride_As");
std::vector<ck_tile::index_t> stride_Bs; std::vector<ck_tile::index_t> stride_Bs = arg_parser.get_int_vec("stride_Bs");
std::vector<ck_tile::index_t> stride_Cs; std::vector<ck_tile::index_t> stride_Cs = arg_parser.get_int_vec("stride_Cs");
if(!valid_input_data(group_count, Ms, Ns, Ks, stride_As, stride_Bs, stride_Cs))
{
std::cout << "Please check the input data. Default values will be used." << std::endl;
for(int i = 0; i < group_count; i++) for(int i = 0; i < group_count; i++)
{ {
Ms.push_back(256 + 256 * i); Ms.push_back(256 + 256 * i);
...@@ -74,6 +81,7 @@ int run_grouped_gemm_example_with_layouts(int argc, ...@@ -74,6 +81,7 @@ int run_grouped_gemm_example_with_layouts(int argc,
stride_Bs.push_back(Ks[i]); stride_Bs.push_back(Ks[i]);
stride_Cs.push_back(Ns[i]); stride_Cs.push_back(Ns[i]);
} }
}
std::vector<ck_tile::HostTensor<ADataType>> a_m_k_tensors; std::vector<ck_tile::HostTensor<ADataType>> a_m_k_tensors;
std::vector<ck_tile::HostTensor<BDataType>> b_k_n_tensors; std::vector<ck_tile::HostTensor<BDataType>> b_k_n_tensors;
......
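Note: the valid_input_data lambda above relies on C++17 fold expressions. A standalone sketch of the idiom (names hypothetical, not part of this commit) that checks every container is non-empty and has exactly group_count entries:

    #include <cstddef>
    #include <vector>

    template <typename... Vecs>
    bool valid_input(std::size_t group_count, const Vecs&... vecs)
    {
        // (vecs.empty() || ...) folds to true if any vector is empty;
        // ((vecs.size() == group_count) && ...) is true only if every size matches.
        return !(vecs.empty() || ...) && ((vecs.size() == group_count) && ...);
    }

    // valid_input(2, std::vector<int>{1, 2}, std::vector<int>{3, 4}) == true
    // valid_input(2, std::vector<int>{1, 2}, std::vector<int>{})     == false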
...@@ -111,6 +111,22 @@ ...@@ -111,6 +111,22 @@
#cmakedefine CK_USE_WMMA @CK_USE_WMMA@ #cmakedefine CK_USE_WMMA @CK_USE_WMMA@
#endif #endif
#ifndef CK_USE_GFX94
#cmakedefine CK_USE_GFX94 @CK_USE_GFX94@
#endif
#ifndef CK_USE_OCP_FP8
#cmakedefine CK_USE_OCP_FP8 @CK_USE_OCP_FP8@
#endif
#ifndef CK_USE_FNUZ_FP8
#cmakedefine CK_USE_FNUZ_FP8 @CK_USE_FNUZ_FP8@
#endif
#ifndef CK_USE_FP8_ON_UNSUPPORTED_ARCH
#cmakedefine CK_USE_FP8_ON_UNSUPPORTED_ARCH @CK_USE_FP8_ON_UNSUPPORTED_ARCH@
#endif
// clang-format on // clang-format on
#endif // CK_CONFIG_H_IN #endif // CK_CONFIG_H_IN
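Note: the four new #cmakedefine entries follow the existing CK_USE_WMMA pattern. Downstream code can gate FP8 paths on them, for example (illustrative only, not part of this commit):

    #if defined(CK_USE_OCP_FP8) && CK_USE_OCP_FP8
        // compile OCP (E4M3/E5M2) FP8 code paths
    #elif defined(CK_USE_FNUZ_FP8) && CK_USE_FNUZ_FP8
        // compile FNUZ FP8 code paths
    #endif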
...@@ -4,7 +4,6 @@ ...@@ -4,7 +4,6 @@
#pragma once #pragma once
#include <hip/hip_runtime.h> #include <hip/hip_runtime.h>
#include <hip/hip_ext.h>
#include <set> #include <set>
#include <vector> #include <vector>
...@@ -43,8 +42,8 @@ struct RotatingMemWrapperMultiD ...@@ -43,8 +42,8 @@ struct RotatingMemWrapperMultiD
{ {
{ {
void* pADeviceBuf; void* pADeviceBuf;
HIP_CHECK_ERROR(hipMalloc(static_cast<void**>(&pADeviceBuf), size_a_)); hip_check_error(hipMalloc(static_cast<void**>(&pADeviceBuf), size_a_));
HIP_CHECK_ERROR(hipMemcpy(static_cast<void*>(pADeviceBuf), hip_check_error(hipMemcpy(static_cast<void*>(pADeviceBuf),
const_cast<void*>(p_a_grids[0]), const_cast<void*>(p_a_grids[0]),
size_a_, size_a_,
hipMemcpyDeviceToDevice)); hipMemcpyDeviceToDevice));
...@@ -53,8 +52,8 @@ struct RotatingMemWrapperMultiD ...@@ -53,8 +52,8 @@ struct RotatingMemWrapperMultiD
{ {
void* pBDeviceBuf; void* pBDeviceBuf;
HIP_CHECK_ERROR(hipMalloc(static_cast<void**>(&pBDeviceBuf), size_b_)); hip_check_error(hipMalloc(static_cast<void**>(&pBDeviceBuf), size_b_));
HIP_CHECK_ERROR(hipMemcpy(static_cast<void*>(pBDeviceBuf), hip_check_error(hipMemcpy(static_cast<void*>(pBDeviceBuf),
const_cast<void*>(p_b_grids[0]), const_cast<void*>(p_b_grids[0]),
size_b_, size_b_,
hipMemcpyDeviceToDevice)); hipMemcpyDeviceToDevice));
...@@ -66,8 +65,8 @@ struct RotatingMemWrapperMultiD ...@@ -66,8 +65,8 @@ struct RotatingMemWrapperMultiD
DsGridPointer ds_buffer; DsGridPointer ds_buffer;
static_for<0, NumDs, 1>{}([&](auto j) { static_for<0, NumDs, 1>{}([&](auto j) {
void* pDDeviceBuf; void* pDDeviceBuf;
HIP_CHECK_ERROR(hipMalloc(static_cast<void**>(&pDDeviceBuf), size_ds_[j])); hip_check_error(hipMalloc(static_cast<void**>(&pDDeviceBuf), size_ds_[j]));
HIP_CHECK_ERROR(hipMemcpy(static_cast<void*>(pDDeviceBuf), hip_check_error(hipMemcpy(static_cast<void*>(pDDeviceBuf),
static_cast<const void*>(p_ds_grids[0][j]), static_cast<const void*>(p_ds_grids[0][j]),
size_ds_[j], size_ds_[j],
hipMemcpyDeviceToDevice)); hipMemcpyDeviceToDevice));
...@@ -94,10 +93,8 @@ struct RotatingMemWrapperMultiD ...@@ -94,10 +93,8 @@ struct RotatingMemWrapperMultiD
} }
void Print() void Print()
{ {
std::cout << "RotatingMemWrapperMultiD: { size_a: " << size_a << ", size_b: " << size_b; std::cout << "RotatingMemWrapperMultiD: { size_a: " << size_a << ", size_b: " << size_b
static_for<0, NumDs, 1>{}( << ", rotating_count: " << rotating_count << "}" << std::endl;
[&](auto j) { std::cout << ", size_d" << j.value << ": " << size_ds[j]; });
std::cout << ", rotating_count: " << rotating_count << "}" << std::endl;
} }
~RotatingMemWrapperMultiD() ~RotatingMemWrapperMultiD()
{ {
...@@ -111,35 +108,13 @@ struct RotatingMemWrapperMultiD ...@@ -111,35 +108,13 @@ struct RotatingMemWrapperMultiD
// free device mem // free device mem
for(size_t i = 1; i < rotating_count; i++) for(size_t i = 1; i < rotating_count; i++)
{ {
try hip_check_error(hipFree(const_cast<void*>(p_a_grids[i])));
{ hip_check_error(hipFree(const_cast<void*>(p_b_grids[i])));
HIP_CHECK_ERROR(hipFree(const_cast<void*>(p_a_grids[i])));
}
catch(std::runtime_error& re)
{
std::cerr << re.what() << std::endl;
}
try
{
HIP_CHECK_ERROR(hipFree(const_cast<void*>(p_b_grids[i])));
}
catch(std::runtime_error& re)
{
std::cerr << re.what() << std::endl;
}
static_for<0, NumDs, 1>{}([&](auto j) { static_for<0, NumDs, 1>{}([&](auto j) {
using DDataType = remove_cvref_t<tuple_element_t<j.value, DsDataType>>; using DDataType = remove_cvref_t<tuple_element_t<j.value, DsDataType>>;
try hip_check_error(
{
HIP_CHECK_ERROR(
hipFree(static_cast<void*>(const_cast<DDataType*>(p_ds_grids[i][j])))); hipFree(static_cast<void*>(const_cast<DDataType*>(p_ds_grids[i][j]))));
}
catch(std::runtime_error& re)
{
std::cerr << re.what() << std::endl;
}
}); });
} }
} }
...@@ -176,8 +151,8 @@ struct RotatingMemWrapper ...@@ -176,8 +151,8 @@ struct RotatingMemWrapper
{ {
{ {
void* pADeviceBuf; void* pADeviceBuf;
HIP_CHECK_ERROR(hipMalloc(static_cast<void**>(&pADeviceBuf), size_a_)); hip_check_error(hipMalloc(static_cast<void**>(&pADeviceBuf), size_a_));
HIP_CHECK_ERROR(hipMemcpy(static_cast<void*>(pADeviceBuf), hip_check_error(hipMemcpy(static_cast<void*>(pADeviceBuf),
const_cast<void*>(p_a_grids[0]), const_cast<void*>(p_a_grids[0]),
size_a_, size_a_,
hipMemcpyDeviceToDevice)); hipMemcpyDeviceToDevice));
...@@ -186,8 +161,8 @@ struct RotatingMemWrapper ...@@ -186,8 +161,8 @@ struct RotatingMemWrapper
{ {
void* pBDeviceBuf; void* pBDeviceBuf;
HIP_CHECK_ERROR(hipMalloc(static_cast<void**>(&pBDeviceBuf), size_b_)); hip_check_error(hipMalloc(static_cast<void**>(&pBDeviceBuf), size_b_));
HIP_CHECK_ERROR(hipMemcpy(static_cast<void*>(pBDeviceBuf), hip_check_error(hipMemcpy(static_cast<void*>(pBDeviceBuf),
const_cast<void*>(p_b_grids[0]), const_cast<void*>(p_b_grids[0]),
size_b_, size_b_,
hipMemcpyDeviceToDevice)); hipMemcpyDeviceToDevice));
...@@ -221,23 +196,8 @@ struct RotatingMemWrapper ...@@ -221,23 +196,8 @@ struct RotatingMemWrapper
// free device mem // free device mem
for(size_t i = 1; i < rotating_count; i++) for(size_t i = 1; i < rotating_count; i++)
{ {
try hip_check_error(hipFree(const_cast<void*>(p_a_grids[i])));
{ hip_check_error(hipFree(const_cast<void*>(p_b_grids[i])));
HIP_CHECK_ERROR(hipFree(const_cast<void*>(p_a_grids[i])));
}
catch(std::runtime_error& re)
{
std::cerr << re.what() << std::endl;
}
try
{
HIP_CHECK_ERROR(hipFree(const_cast<void*>(p_b_grids[i])));
}
catch(std::runtime_error& re)
{
std::cerr << re.what() << std::endl;
}
} }
} }
} }
...@@ -255,20 +215,25 @@ struct RotatingMemWrapper ...@@ -255,20 +215,25 @@ struct RotatingMemWrapper
inline void flush_icache() inline void flush_icache()
{ {
hipDeviceProp_t deviceProps; hipDeviceProp_t deviceProps;
HIP_CHECK_ERROR(hipGetDeviceProperties(&deviceProps, 0)); hip_check_error(hipGetDeviceProperties(&deviceProps, 0));
int32_t gpu_block3 = deviceProps.multiProcessorCount * 60; int32_t gpu_block3 = deviceProps.multiProcessorCount * 60;
ck::flush_icache<<<dim3(gpu_block3), dim3(64), 0, nullptr>>>(); ck::flush_icache<<<dim3(gpu_block3), dim3(64), 0, nullptr>>>();
HIP_CHECK_ERROR(hipGetLastError()); hip_check_error(hipGetLastError());
} }
// if TimePrePress == false, return time does not include preprocess's time // if TimePrePress == false, return time does not include preprocess's time
template <bool TimePreprocess, typename... Args, typename F, typename PreProcessFunc> template <bool TimePreprocess,
typename GemmArgs,
typename... Args,
typename F,
typename PreProcessFunc>
float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config, float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config,
PreProcessFunc preprocess, PreProcessFunc preprocess,
F kernel, F kernel,
dim3 grid_dim, dim3 grid_dim,
dim3 block_dim, dim3 block_dim,
std::size_t lds_byte, std::size_t lds_byte,
GemmArgs& gemm_args,
Args... args) Args... args)
{ {
#if CK_TIME_KERNEL #if CK_TIME_KERNEL
...@@ -291,8 +256,8 @@ float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config, ...@@ -291,8 +256,8 @@ float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config,
// warm up // warm up
for(int i = 0; i < stream_config.cold_niters_; ++i) for(int i = 0; i < stream_config.cold_niters_; ++i)
{ {
kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(args...); kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(gemm_args, args...);
HIP_CHECK_ERROR(hipGetLastError()); hip_check_error(hipGetLastError());
} }
const int nrepeat = stream_config.nrepeat_; const int nrepeat = stream_config.nrepeat_;
...@@ -312,36 +277,54 @@ float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config, ...@@ -312,36 +277,54 @@ float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config,
#endif #endif
hipEvent_t start, stop; hipEvent_t start, stop;
HIP_CHECK_ERROR(hipEventCreate(&start)); hip_check_error(hipEventCreate(&start));
HIP_CHECK_ERROR(hipEventCreate(&stop)); hip_check_error(hipEventCreate(&stop));
hip_check_error(hipDeviceSynchronize());
hip_check_error(hipEventRecord(start, stream_config.stream_id_));
for(int i = 0; i < nrepeat; ++i) for(int i = 0; i < nrepeat; ++i)
{
if constexpr(!TimePreprocess)
{ {
preprocess(); preprocess();
}
// hipEvent_t start, stop;
// hip_check_error(hipEventCreate(&start));
// hip_check_error(hipEventCreate(&stop));
// hip_check_error(hipDeviceSynchronize());
// hip_check_error(hipEventRecord(start, stream_config.stream_id_));
// calculate preprocess time
if constexpr(TimePreprocess)
{
preprocess();
}
// run real kernel // run real kernel
hipExtLaunchKernelGGL(kernel, kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(gemm_args, args...);
grid_dim, hip_check_error(hipGetLastError());
block_dim,
lds_byte,
stream_config.stream_id_,
start,
stop,
0,
args...);
HIP_CHECK_ERROR(hipGetLastError());
// end real kernel // end real kernel
HIP_CHECK_ERROR(hipEventRecord(stop, stream_config.stream_id_)); // hip_check_error(hipEventRecord(stop, stream_config.stream_id_));
HIP_CHECK_ERROR(hipEventSynchronize(stop)); // hip_check_error(hipEventSynchronize(stop));
// float cur_time = 0;
// hip_check_error(hipEventElapsedTime(&cur_time, start, stop));
// #if MEDIAN
// times.insert(cur_time);
// #else
// total_time += cur_time;
// #endif
float cur_time = 0; if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
HIP_CHECK_ERROR(hipEventElapsedTime(&cur_time, start, stop)); {
#if MEDIAN // std::cout << "i: " << i << " cur_time: " << cur_time << std::endl;
times.insert(cur_time);
#else printf("gemm_args.p_a_grid: %p, gemm_args.p_b_grid:%p\n",
total_time += cur_time; static_cast<const void*>(gemm_args.p_a_grid),
#endif static_cast<const void*>(gemm_args.p_b_grid));
}
} }
hip_check_error(hipEventRecord(stop, stream_config.stream_id_)); hip_check_error(hipEventRecord(stop, stream_config.stream_id_));
hip_check_error(hipEventSynchronize(stop)); hip_check_error(hipEventSynchronize(stop));
...@@ -367,20 +350,24 @@ float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config, ...@@ -367,20 +350,24 @@ float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config,
return (*mid + *mid_next) / 2; return (*mid + *mid_next) / 2;
} }
#else #else
return total_time / nrepeat; // return total_time / nrepeat;
hipDeviceProp_t deviceProps;
hip_check_error(hipGetDeviceProperties(&deviceProps, 0));
float preprocess_offset = deviceProps.multiProcessorCount == 80 ? 0.005 : 0.01;
return (total_time - preprocess_offset * nrepeat) / nrepeat;
#endif #endif
} }
else else
{ {
preprocess(); preprocess();
kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(args...); kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(gemm_args, args...);
HIP_CHECK_ERROR(hipGetLastError()); hip_check_error(hipGetLastError());
return 0; return 0;
} }
#else #else
kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(args...); kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(gemm_args, args...);
HIP_CHECK_ERROR(hipGetLastError()); hip_check_error(hipGetLastError());
return 0; return 0;
#endif #endif
......
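Note: the timing scheme changes here. The per-iteration hipExtLaunchKernelGGL event pair is gone; a single start/stop event pair now brackets the whole repeat loop and a fixed per-iteration preprocess offset is subtracted from the average. A condensed sketch of the new flow (error handling and logging omitted; the elapsed-time query lives in lines elided from the diff and is assumed to use hipEventElapsedTime):

    hipEvent_t start, stop;
    hip_check_error(hipEventCreate(&start));
    hip_check_error(hipEventCreate(&stop));
    hip_check_error(hipDeviceSynchronize());
    hip_check_error(hipEventRecord(start, stream_config.stream_id_));
    for(int i = 0; i < nrepeat; ++i)
    {
        preprocess(); // timed or untimed depending on TimePreprocess
        kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(gemm_args, args...);
    }
    hip_check_error(hipEventRecord(stop, stream_config.stream_id_));
    hip_check_error(hipEventSynchronize(stop));
    float total_time = 0;
    hip_check_error(hipEventElapsedTime(&total_time, start, stop));
    // preprocess_offset (0.005 ms or 0.01 ms depending on CU count) approximates
    // the flush_icache cost folded into each iteration
    return (total_time - preprocess_offset * nrepeat) / nrepeat;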
...@@ -477,6 +477,9 @@ struct BlockwiseGemmXdlops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave, ...@@ -477,6 +477,9 @@ struct BlockwiseGemmXdlops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
b_thread_buf_tail); b_thread_buf_tail);
}); });
}); });
HotLoopScheduler();
__builtin_amdgcn_sched_barrier(0);
} }
} }
...@@ -692,6 +695,9 @@ struct BlockwiseGemmXdlops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave, ...@@ -692,6 +695,9 @@ struct BlockwiseGemmXdlops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
}); });
}); });
HotLoopScheduler();
__builtin_amdgcn_sched_barrier(0);
static_for<0, KRepeat, 1>{}([&](auto k0) { static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, MRepeat, 1>{}([&](auto m0) { static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) { static_for<0, NRepeat, 1>{}([&](auto n0) {
......
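Note on the two added lines: __builtin_amdgcn_sched_barrier takes a mask of instruction categories allowed to cross it, and a mask of 0 keeps the compiler from moving any instruction across the barrier, so the hints issued by HotLoopScheduler() stay anchored to this point of the hot loop.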
...@@ -5,6 +5,8 @@ ...@@ -5,6 +5,8 @@
#include <string> #include <string>
#include <sstream> #include <sstream>
#include <regex>
#include <optional>
#include "ck/stream_config.hpp" #include "ck/stream_config.hpp"
...@@ -12,6 +14,34 @@ namespace ck { ...@@ -12,6 +14,34 @@ namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace device { namespace device {
#define GET_OBJECT_NAME_IMLP \
std::optional<std::string> GetObjectName() const override \
{ \
std::string str = __PRETTY_FUNCTION__; \
static std::regex obj_name_expr{"<std::string> (.*)::GetObjectName"}; \
std::smatch match; \
if(!std::regex_search(str, match, obj_name_expr)) \
{ \
return str; \
} \
return std::string(match[1]) + ';'; \
}
#define GET_TEMPLATE_INFO_IMPL \
std::optional<std::string> GetTemplateInfo() const override \
{ \
std::string str = __PRETTY_FUNCTION__; \
static std::regex template_expr{"\\[(.*)\\]"}; \
std::smatch match; \
if(!std::regex_search(str, match, template_expr)) \
{ \
return std::nullopt; \
} \
return std::string(match[1]); \
}
#define REGISTER_EXTRA_PRINTING_METHODS GET_OBJECT_NAME_IMLP GET_TEMPLATE_INFO_IMPL
struct BaseArgument struct BaseArgument
{ {
BaseArgument() = default; BaseArgument() = default;
...@@ -48,6 +78,10 @@ struct BaseOperator ...@@ -48,6 +78,10 @@ struct BaseOperator
virtual std::string GetTypeIdName() const { return typeid(*this).name(); } virtual std::string GetTypeIdName() const { return typeid(*this).name(); }
virtual std::optional<std::string> GetObjectName() const { return std::nullopt; }
virtual std::optional<std::string> GetTemplateInfo() const { return std::nullopt; }
virtual std::string GetTypeIdHashCode() const virtual std::string GetTypeIdHashCode() const
{ {
std::ostringstream oss; std::ostringstream oss;
......
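Note: the two new virtuals, together with the macros above, let callers print the exact instantiated type of a device operation. A usage sketch through the base interface (op_ptr stands for any operation that adds REGISTER_EXTRA_PRINTING_METHODS, e.g. DeviceGemm_Xdl_CShuffleV3 below; the exact strings depend on the compiler's __PRETTY_FUNCTION__ format):

    // op_ptr: std::unique_ptr<ck::tensor_operation::device::BaseOperator>, obtained elsewhere
    if(std::optional<std::string> name = op_ptr->GetObjectName())
        std::cout << *name << std::endl; // e.g. "ck::tensor_operation::device::DeviceGemm_Xdl_CShuffleV3<...>;"
    if(std::optional<std::string> info = op_ptr->GetTemplateInfo())
        std::cout << *info << std::endl; // template parameter bindings captured by the [...] regex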
...@@ -89,7 +89,8 @@ struct DeviceBatchedGemmV2MultiD : public BaseOperator ...@@ -89,7 +89,8 @@ struct DeviceBatchedGemmV2MultiD : public BaseOperator
index_t BatchStrideE, index_t BatchStrideE,
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op) = 0; CDEElementwiseOperation cde_element_op,
index_t KBatch) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0; virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
}; };
......
...@@ -41,12 +41,15 @@ __global__ void ...@@ -41,12 +41,15 @@ __global__ void
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
const index_t g_idx = blockIdx.z % karg.Batch; const index_t g_idx = blockIdx.z % karg.Batch;
const index_t k_idx = blockIdx.z / karg.Batch;
const auto a_batch_offset = karg.compute_ptr_offset_of_batch.GetAPtrOffset(g_idx); const auto a_batch_offset = karg.compute_ptr_offset_of_batch.GetAPtrOffset(g_idx);
const auto b_batch_offset = karg.compute_ptr_offset_of_batch.GetBPtrOffset(g_idx); const auto b_batch_offset = karg.compute_ptr_offset_of_batch.GetBPtrOffset(g_idx);
const auto ds_batch_offset = karg.compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx); const auto ds_batch_offset = karg.compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx);
const auto c_batch_offset = karg.compute_ptr_offset_of_batch.GetCPtrOffset(g_idx); const auto c_batch_offset = karg.compute_ptr_offset_of_batch.GetCPtrOffset(g_idx);
auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, k_idx);
// populate pointer, desc for Ds // populate pointer, desc for Ds
static_for<0, GridwiseGemm::NumDTensor, 1>{}([&](auto i) { static_for<0, GridwiseGemm::NumDTensor, 1>{}([&](auto i) {
// D pointer // D pointer
...@@ -54,8 +57,8 @@ __global__ void ...@@ -54,8 +57,8 @@ __global__ void
}); });
GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>( GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
karg.p_a_grid + a_batch_offset, karg.p_a_grid + a_batch_offset + splitk_batch_offset.a_k_split_offset,
karg.p_b_grid + b_batch_offset, karg.p_b_grid + b_batch_offset + splitk_batch_offset.b_k_split_offset,
karg.p_ds_grid, karg.p_ds_grid,
karg.p_c_grid + c_batch_offset, karg.p_c_grid + c_batch_offset,
p_shared, p_shared,
...@@ -87,12 +90,15 @@ __global__ void ...@@ -87,12 +90,15 @@ __global__ void
__shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()]; __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()];
const index_t g_idx = blockIdx.z % karg.Batch; const index_t g_idx = blockIdx.z % karg.Batch;
const index_t k_idx = blockIdx.z / karg.Batch;
const auto a_batch_offset = karg.compute_ptr_offset_of_batch.GetAPtrOffset(g_idx); const auto a_batch_offset = karg.compute_ptr_offset_of_batch.GetAPtrOffset(g_idx);
const auto b_batch_offset = karg.compute_ptr_offset_of_batch.GetBPtrOffset(g_idx); const auto b_batch_offset = karg.compute_ptr_offset_of_batch.GetBPtrOffset(g_idx);
const auto ds_batch_offset = karg.compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx); const auto ds_batch_offset = karg.compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx);
const auto c_batch_offset = karg.compute_ptr_offset_of_batch.GetCPtrOffset(g_idx); const auto c_batch_offset = karg.compute_ptr_offset_of_batch.GetCPtrOffset(g_idx);
auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, k_idx);
// populate pointer, desc for Ds // populate pointer, desc for Ds
static_for<0, GridwiseGemm::NumDTensor, 1>{}([&](auto i) { static_for<0, GridwiseGemm::NumDTensor, 1>{}([&](auto i) {
// D pointer // D pointer
...@@ -100,8 +106,8 @@ __global__ void ...@@ -100,8 +106,8 @@ __global__ void
}); });
GridwiseGemm::template Run_2Lds<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>( GridwiseGemm::template Run_2Lds<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
karg.p_a_grid + a_batch_offset, karg.p_a_grid + a_batch_offset + splitk_batch_offset.a_k_split_offset,
karg.p_b_grid + b_batch_offset, karg.p_b_grid + b_batch_offset + splitk_batch_offset.b_k_split_offset,
karg.p_ds_grid, karg.p_ds_grid,
karg.p_c_grid + c_batch_offset, karg.p_c_grid + c_batch_offset,
p_shared_0, p_shared_0,
...@@ -303,7 +309,8 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3 ...@@ -303,7 +309,8 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3
index_t Batch_, index_t Batch_,
AElementwiseOperation a_element_op_, AElementwiseOperation a_element_op_,
BElementwiseOperation b_element_op_, BElementwiseOperation b_element_op_,
CElementwiseOperation c_element_op_) CElementwiseOperation c_element_op_,
index_t KBatch_)
: GridwiseGemm::Argument{p_a_grid_, : GridwiseGemm::Argument{p_a_grid_,
p_b_grid_, p_b_grid_,
p_ds_grid_, p_ds_grid_,
...@@ -315,7 +322,7 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3 ...@@ -315,7 +322,7 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3
StrideB_, StrideB_,
StrideDs_, StrideDs_,
StrideE_, StrideE_,
1, KBatch_,
a_element_op_, a_element_op_,
b_element_op_, b_element_op_,
c_element_op_}, c_element_op_},
...@@ -336,13 +343,14 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3 ...@@ -336,13 +343,14 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3
arg.Print(); arg.Print();
} }
if(!GridwiseGemm::CheckValidity(arg) || arg.KBatch > 1) if(!GridwiseGemm::CheckValidity(arg))
{ {
throw std::runtime_error("wrong! GridwiseGemm has invalid setting"); throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
} }
index_t gdx, gdy, gdz; index_t gdx, gdy, gdz;
std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N, arg.Batch); std::tie(gdx, gdy, gdz) =
GridwiseGemm::CalculateGridSize(arg.M, arg.N, arg.Batch * arg.KBatch);
float ave_time = 0; float ave_time = 0;
...@@ -387,9 +395,10 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3 ...@@ -387,9 +395,10 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3
rotating_mem.Next(); rotating_mem.Next();
// clear c mem // clear c mem
if(arg_.KBatch > 1) if(arg_.KBatch > 1)
hipGetErrorString(hipMemsetAsync(arg_.p_c_grid, hipGetErrorString(
hipMemsetAsync(arg_.p_c_grid,
0, 0,
arg_.M * arg_.N * sizeof(CDataType), arg.Batch * arg_.M * arg_.N * sizeof(CDataType),
stream_config.stream_id_)); stream_config.stream_id_));
}; };
...@@ -889,7 +898,8 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3 ...@@ -889,7 +898,8 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3
index_t BatchStrideE, index_t BatchStrideE,
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op) CElementwiseOperation c_element_op,
index_t KBatch = 1)
{ {
return Argument{static_cast<const ADataType*>(p_a), return Argument{static_cast<const ADataType*>(p_a),
static_cast<const BDataType*>(p_b), static_cast<const BDataType*>(p_b),
...@@ -909,7 +919,8 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3 ...@@ -909,7 +919,8 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3
Batch, Batch,
a_element_op, a_element_op,
b_element_op, b_element_op,
c_element_op}; c_element_op,
KBatch};
} }
static auto MakeInvoker() { return Invoker{}; } static auto MakeInvoker() { return Invoker{}; }
...@@ -934,7 +945,8 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3 ...@@ -934,7 +945,8 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3
index_t BatchStrideE, index_t BatchStrideE,
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op) override CElementwiseOperation c_element_op,
index_t KBatch = 1) override
{ {
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a), return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
static_cast<const BDataType*>(p_b), static_cast<const BDataType*>(p_b),
...@@ -954,7 +966,8 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3 ...@@ -954,7 +966,8 @@ struct DeviceBatchedGemmMultiD_Xdl_CShuffle_V3
Batch, Batch,
a_element_op, a_element_op,
b_element_op, b_element_op,
c_element_op); c_element_op,
KBatch);
} }
// polymorphic // polymorphic
......
...@@ -741,6 +741,7 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout, ...@@ -741,6 +741,7 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
return str.str(); return str.str();
} }
REGISTER_EXTRA_PRINTING_METHODS
}; };
} // namespace device } // namespace device
......
...@@ -41,7 +41,7 @@ __global__ void ...@@ -41,7 +41,7 @@ __global__ void
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg); auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>( GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
karg.p_a_grid + splitk_batch_offset.a_k_split_offset, karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
...@@ -76,7 +76,7 @@ __global__ void ...@@ -76,7 +76,7 @@ __global__ void
__shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()]; __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()];
__shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()]; __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()];
auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg); auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
GridwiseGemm::template Run_2Lds<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>( GridwiseGemm::template Run_2Lds<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
karg.p_a_grid + splitk_batch_offset.a_k_split_offset, karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
...@@ -639,27 +639,27 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3 ...@@ -639,27 +639,27 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
struct SplitKBatchOffset struct SplitKBatchOffset
{ {
__device__ SplitKBatchOffset(Argument& karg) __device__ SplitKBatchOffset(Argument& karg, index_t k_id)
{ {
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>) if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
{ {
a_k_split_offset = blockIdx.z * karg.KRead; a_k_split_offset = k_id * karg.KRead;
} }
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>) else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
{ {
a_k_split_offset = blockIdx.z * karg.KRead * karg.StrideA; a_k_split_offset = k_id * karg.KRead * karg.StrideA;
} }
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>) if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
{ {
b_k_split_offset = blockIdx.z * karg.KRead * karg.StrideB; b_k_split_offset = k_id * karg.KRead * karg.StrideB;
} }
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>) else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
{ {
b_k_split_offset = blockIdx.z * karg.KRead; b_k_split_offset = k_id * karg.KRead;
} }
if(blockIdx.z < static_cast<uint32_t>(karg.KBatch - 1)) if(k_id < karg.KBatch - 1)
{ {
karg.K = karg.KRead; karg.K = karg.KRead;
} }
......
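Worked example of the new split-K indexing (values hypothetical): with Batch = 4 and KBatch = 2 the grid z-dimension becomes Batch * KBatch = 8, and the batched kernels above decode it as

    // blockIdx.z = 5  ->  g_idx = 5 % 4 = 1 (second GEMM of the batch)
    //                     k_idx = 5 / 4 = 1 (second K split)
    // SplitKBatchOffset(karg, k_idx) then shifts the A/B pointers by
    //   a_k_split_offset = k_idx * karg.KRead                 (row-major A)
    //   a_k_split_offset = k_idx * karg.KRead * karg.StrideA  (column-major A)
    //   b_k_split_offset = k_idx * karg.KRead * karg.StrideB  (row-major B)
    //   b_k_split_offset = k_idx * karg.KRead                 (column-major B)
    // splits with k_id < KBatch - 1 work on a KRead-sized slice, while the last
    // split handles whatever K remains (that else-branch is elided from the diff).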
...@@ -18,6 +18,20 @@ ...@@ -18,6 +18,20 @@
#define CK_USE_OCP_FP8 0 #define CK_USE_OCP_FP8 0
#endif #endif
namespace {
// https://en.cppreference.com/w/cpp/types/conditional
template <bool B, class T, class F>
struct conditional
{
using type = T;
};
template <class T, class F>
struct conditional<false, T, F>
{
using type = F;
};
} // namespace
namespace ck { namespace ck {
using f8_fnuz_t = _BitInt(8); using f8_fnuz_t = _BitInt(8);
...@@ -191,11 +205,10 @@ __host__ __device__ static inline T cast_from_f8(fp8_storage_t x) ...@@ -191,11 +205,10 @@ __host__ __device__ static inline T cast_from_f8(fp8_storage_t x)
} }
} }
typename __hip_internal::conditional< typename conditional<
sizeof(T) == 2, sizeof(T) == 2,
unsigned short int, unsigned short int,
typename __hip_internal::conditional<sizeof(T) == 4, unsigned int, unsigned long long>:: typename conditional<sizeof(T) == 4, unsigned int, unsigned long long>::type>::type retval;
type>::type retval;
if constexpr(we == 5 && is_half && !is_fnuz) if constexpr(we == 5 && is_half && !is_fnuz)
{ {
...@@ -538,11 +551,10 @@ __host__ __device__ static inline fp8_storage_t cast_to_f8(T _x, unsigned int rn ...@@ -538,11 +551,10 @@ __host__ __device__ static inline fp8_storage_t cast_to_f8(T _x, unsigned int rn
constexpr int mfmt = (sizeof(T) == 8) ? 52 : ((sizeof(T) == 4) ? 23 : 10); constexpr int mfmt = (sizeof(T) == 8) ? 52 : ((sizeof(T) == 4) ? 23 : 10);
using T_bitwise = typename __hip_internal::conditional< using T_bitwise = typename conditional<
sizeof(T) == 2, sizeof(T) == 2,
unsigned short int, unsigned short int,
typename __hip_internal::conditional<sizeof(T) == 4, unsigned int, unsigned long long>:: typename conditional<sizeof(T) == 4, unsigned int, unsigned long long>::type>::type;
type>::type;
T_bitwise x_bitwise = bit_cast<T_bitwise>(_x); T_bitwise x_bitwise = bit_cast<T_bitwise>(_x);
unsigned long long x{x_bitwise}; unsigned long long x{x_bitwise};
......
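Note: the header now carries its own minimal conditional<> instead of depending on __hip_internal::conditional; it behaves like std::conditional and is used to pick a same-width unsigned integer for bitwise access. A small sketch of the selection it performs (T = float shown):

    using T = float;
    using T_bitwise = typename conditional<
        sizeof(T) == 2,
        unsigned short int,
        typename conditional<sizeof(T) == 4, unsigned int, unsigned long long>::type>::type;
    static_assert(sizeof(T_bitwise) == sizeof(T), "bitwise type must match width"); // picks unsigned int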
...@@ -1303,8 +1303,8 @@ CK_TILE_DEVICE thread_buffer<T, N> amd_buffer_load_impl(int32x4_t src_wave_buffe ...@@ -1303,8 +1303,8 @@ CK_TILE_DEVICE thread_buffer<T, N> amd_buffer_load_impl(int32x4_t src_wave_buffe
static_assert( static_assert(
(std::is_same<T, double>::value && (N == 1 || N == 2 || N == 4 || N == 8)) || (std::is_same<T, double>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
(std::is_same<T, float>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || (std::is_same<T, float>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
(std::is_same<T, fp16_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || (std::is_same<T, fp16_t>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
(std::is_same<T, bf16_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || (std::is_same<T, bf16_t>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
(std::is_same<T, int32_t>::value && (std::is_same<T, int32_t>::value &&
(N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
(std::is_same<T, fp8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || (std::is_same<T, fp8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
......
...@@ -30,7 +30,7 @@ struct meta_data_buffer ...@@ -30,7 +30,7 @@ struct meta_data_buffer
{ {
constexpr index_t size = sizeof(T); constexpr index_t size = sizeof(T);
auto tmp = bit_cast<array<std::byte, size>>(data); auto tmp = ck_tile::bit_cast<array<std::byte, size>>(data);
for(int i = 0; i < size; i++) for(int i = 0; i < size; i++)
{ {
...@@ -66,7 +66,7 @@ struct meta_data_buffer ...@@ -66,7 +66,7 @@ struct meta_data_buffer
pos++; pos++;
} }
data = bit_cast<T>(tmp); data = ck_tile::bit_cast<T>(tmp);
} }
return data; return data;
...@@ -86,7 +86,7 @@ struct meta_data_buffer ...@@ -86,7 +86,7 @@ struct meta_data_buffer
pos++; pos++;
} }
auto data = bit_cast<T>(tmp); auto data = ck_tile::bit_cast<T>(tmp);
return data; return data;
} }
......
...@@ -29,6 +29,7 @@ struct static_distributed_tensor ...@@ -29,6 +29,7 @@ struct static_distributed_tensor
remove_cvref_t<decltype(StaticTileDistribution{}.get_ys_to_d_descriptor())>; remove_cvref_t<decltype(StaticTileDistribution{}.get_ys_to_d_descriptor())>;
static constexpr index_t kThreadElementSpaceSize = ThreadTensorDesc{}.get_element_space_size(); static constexpr index_t kThreadElementSpaceSize = ThreadTensorDesc{}.get_element_space_size();
static_assert(0 < kThreadElementSpaceSize, "Make sure tile distribution is valid");
CK_TILE_HOST_DEVICE static constexpr auto get_num_of_dimension() CK_TILE_HOST_DEVICE static constexpr auto get_num_of_dimension()
{ {
......
...@@ -15,11 +15,14 @@ ...@@ -15,11 +15,14 @@
namespace ck_tile { namespace ck_tile {
/* /*
* a host side utility, arg parser for * a host side utility, arg parser for, either
* -[key0] = [value0, value1, value2]
* or
* -[key0]=[value0] -[key1]=[value1] ... * -[key0]=[value0] -[key1]=[value1] ...
*/ */
class ArgParser class ArgParser
{ {
public: public:
class Arg class Arg
{ {
...@@ -187,6 +190,45 @@ class ArgParser ...@@ -187,6 +190,45 @@ class ArgParser
return value; return value;
} }
std::vector<std::string> get_string_vec(const std::string& name,
const std::string& delimiter = ",") const
{
if(get_str(name).empty())
{
return {};
}
std::string s = get_str(name);
std::vector<std::string> tokens;
size_t pos = 0;
std::string token;
while((pos = s.find(delimiter)) != std::string::npos)
{
token = s.substr(0, pos);
tokens.push_back(token);
s.erase(0, pos + delimiter.length());
}
tokens.push_back(s);
return tokens;
}
std::vector<int> get_int_vec(const std::string& name, const std::string& delimiter = ",") const
{
if(get_str(name).empty())
{
return {};
}
const std::vector<std::string> args = get_string_vec(name, delimiter);
std::vector<int> tokens;
tokens.reserve(static_cast<int>(args.size()));
for(const std::string& token : args)
{
int value = atoi(token.c_str());
tokens.push_back(value);
}
return tokens;
}
private: private:
std::unordered_map<std::string, Arg> input_map; std::unordered_map<std::string, Arg> input_map;
std::vector<std::string> keys; std::vector<std::string> keys;
......
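Usage sketch for the new getters, assuming an option registered as a comma-separated list (flag syntax follows the -key=value convention noted in the header comment):

    ck_tile::ArgParser arg_parser;
    arg_parser.insert("Ms", "", "M dimensions - empty by default.");
    arg_parser.parse(argc, argv);                                   // e.g. invoked with -Ms=256,512,1024
    std::vector<std::string> raw = arg_parser.get_string_vec("Ms"); // {"256", "512", "1024"}
    std::vector<int> Ms = arg_parser.get_int_vec("Ms");             // {256, 512, 1024}
    // an unset option yields an empty vector, which the examples treat as "use defaults"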
...@@ -97,9 +97,9 @@ template <typename ADataType, ...@@ -97,9 +97,9 @@ template <typename ADataType,
typename LayoutA, typename LayoutA,
typename LayoutB, typename LayoutB,
typename LayoutC> typename LayoutC>
void reference_gemm_gpu(DeviceMem& a_device, void reference_gemm_gpu(ADataType* a_ptr,
DeviceMem& b_device, BDataType* b_ptr,
DeviceMem& c_device, CDataType* c_ptr,
index_t M, index_t M,
index_t N, index_t N,
index_t K, index_t K,
...@@ -107,79 +107,13 @@ void reference_gemm_gpu(DeviceMem& a_device, ...@@ -107,79 +107,13 @@ void reference_gemm_gpu(DeviceMem& a_device,
index_t stride_b, index_t stride_b,
index_t stride_c) index_t stride_c)
{ {
ADataType* d_A;
BDataType* d_B;
CDataType* d_C;
hipError_t errA = hipMalloc(&d_A, M * K * sizeof(ADataType));
hipError_t errB = hipMalloc(&d_B, N * K * sizeof(BDataType));
hipError_t errC = hipMalloc(&d_C, M * N * sizeof(CDataType));
if(errA != hipSuccess)
{
std::cerr << "Error allocating device memory for A: " << hipGetErrorString(errA)
<< std::endl;
return; // Early exit on error
}
if(errB != hipSuccess)
{
std::cerr << "Error allocating device memory for B: " << hipGetErrorString(errB)
<< std::endl;
return; // Early exit on error
}
if(errC != hipSuccess)
{
std::cerr << "Error allocating device memory for C: " << hipGetErrorString(errC)
<< std::endl;
return; // Early exit on error
}
errA = hipMemcpy(
d_A, a_device.GetDeviceBuffer(), M * K * sizeof(ADataType), hipMemcpyHostToDevice);
if(errA != hipSuccess)
{
std::cerr << "Error copying A to device: " << hipGetErrorString(errA) << std::endl;
}
errB = hipMemcpy(
d_B, b_device.GetDeviceBuffer(), N * K * sizeof(BDataType), hipMemcpyHostToDevice);
if(errB != hipSuccess)
{
std::cerr << "Error copying B to device: " << hipGetErrorString(errB) << std::endl;
}
int totalElements = M * N; int totalElements = M * N;
int numThreadsPerBlock = 256; // Common choice for threads per block int numThreadsPerBlock = 256; // Common choice for threads per block
int numBlocks = (totalElements + numThreadsPerBlock - 1) / numThreadsPerBlock; int numBlocks = (totalElements + numThreadsPerBlock - 1) / numThreadsPerBlock;
naive_gemm_kernel<ADataType, BDataType, AccDataType, CDataType, LayoutA, LayoutB, LayoutC> naive_gemm_kernel<ADataType, BDataType, AccDataType, CDataType, LayoutA, LayoutB, LayoutC>
<<<numBlocks, numThreadsPerBlock>>>(d_A, d_B, d_C, M, N, K, stride_a, stride_b, stride_c); <<<numBlocks, numThreadsPerBlock>>>(
errC = hipMemcpy( a_ptr, b_ptr, c_ptr, M, N, K, stride_a, stride_b, stride_c);
c_device.GetDeviceBuffer(), d_C, M * N * sizeof(CDataType), hipMemcpyDeviceToHost);
if(errC != hipSuccess)
{
std::cerr << "Error copying C to device: " << hipGetErrorString(errC) << std::endl;
}
errA = hipFree(d_A);
if(errA != hipSuccess)
{
std::cerr << "Error free the A memory: " << hipGetErrorString(errA) << std::endl;
}
errB = hipFree(d_B);
if(errB != hipSuccess)
{
std::cerr << "Error free the B memory: " << hipGetErrorString(errB) << std::endl;
}
errC = hipFree(d_C);
if(errC != hipSuccess)
{
std::cerr << "Error free the C memory: " << hipGetErrorString(errC) << std::endl;
}
return; return;
} }
...@@ -191,9 +125,9 @@ template <typename ADataType, ...@@ -191,9 +125,9 @@ template <typename ADataType,
typename LayoutA, typename LayoutA,
typename LayoutB, typename LayoutB,
typename LayoutC> typename LayoutC>
void reference_batched_gemm_gpu(DeviceMem& a_device, void reference_batched_gemm_gpu(ADataType* a_ptr,
DeviceMem& b_device, BDataType* b_ptr,
DeviceMem& c_device, CDataType* c_ptr,
index_t M, index_t M,
index_t N, index_t N,
index_t K, index_t K,
...@@ -205,94 +139,20 @@ void reference_batched_gemm_gpu(DeviceMem& a_device, ...@@ -205,94 +139,20 @@ void reference_batched_gemm_gpu(DeviceMem& a_device,
index_t batch_stride_C, index_t batch_stride_C,
index_t batch_count) index_t batch_count)
{ {
ADataType* d_A;
BDataType* d_B;
CDataType* d_C;
hipError_t errA = hipMalloc(&d_A, batch_count * M * K * sizeof(ADataType));
hipError_t errB = hipMalloc(&d_B, batch_count * N * K * sizeof(BDataType));
hipError_t errC = hipMalloc(&d_C, batch_count * M * N * sizeof(CDataType));
if(errA != hipSuccess)
{
std::cerr << "Error allocating device memory for A: " << hipGetErrorString(errA)
<< std::endl;
return; // Early exit on error
}
if(errB != hipSuccess)
{
std::cerr << "Error allocating device memory for B: " << hipGetErrorString(errB)
<< std::endl;
return; // Early exit on error
}
if(errC != hipSuccess)
{
std::cerr << "Error allocating device memory for C: " << hipGetErrorString(errC)
<< std::endl;
return; // Early exit on error
}
errA = hipMemcpy(d_A,
a_device.GetDeviceBuffer(),
batch_count * M * K * sizeof(ADataType),
hipMemcpyHostToDevice);
if(errA != hipSuccess)
{
std::cerr << "Error copying A to device: " << hipGetErrorString(errA) << std::endl;
}
errB = hipMemcpy(d_B,
b_device.GetDeviceBuffer(),
batch_count * N * K * sizeof(BDataType),
hipMemcpyHostToDevice);
if(errB != hipSuccess)
{
std::cerr << "Error copying B to device: " << hipGetErrorString(errB) << std::endl;
}
int totalElements = M * N; int totalElements = M * N;
int numThreadsPerBlock = 256; // Common choice for threads per block int numThreadsPerBlock = 256; // Common choice for threads per block
int numBlocks = (totalElements + numThreadsPerBlock - 1) / numThreadsPerBlock; int numBlocks = (totalElements + numThreadsPerBlock - 1) / numThreadsPerBlock;
for(index_t batch_id = 0; batch_id < batch_count; ++batch_id) for(index_t batch_id = 0; batch_id < batch_count; ++batch_id)
{ {
ADataType* d_ATemp = d_A + batch_id * batch_stride_A; ADataType* d_ATemp = a_ptr + batch_id * batch_stride_A;
BDataType* d_BTemp = d_B + batch_id * batch_stride_B; BDataType* d_BTemp = b_ptr + batch_id * batch_stride_B;
CDataType* d_CTemp = d_C + batch_id * batch_stride_C; CDataType* d_CTemp = c_ptr + batch_id * batch_stride_C;
naive_gemm_kernel<ADataType, BDataType, AccDataType, CDataType, LayoutA, LayoutB, LayoutC> naive_gemm_kernel<ADataType, BDataType, AccDataType, CDataType, LayoutA, LayoutB, LayoutC>
<<<numBlocks, numThreadsPerBlock>>>( <<<numBlocks, numThreadsPerBlock>>>(
d_ATemp, d_BTemp, d_CTemp, M, N, K, stride_a, stride_b, stride_c); d_ATemp, d_BTemp, d_CTemp, M, N, K, stride_a, stride_b, stride_c);
} }
errC = hipMemcpy(c_device.GetDeviceBuffer(),
d_C,
batch_count * M * N * sizeof(CDataType),
hipMemcpyDeviceToHost);
if(errC != hipSuccess)
{
std::cerr << "Error copying C to device: " << hipGetErrorString(errC) << std::endl;
}
errA = hipFree(d_A);
if(errA != hipSuccess)
{
std::cerr << "Error free the A memory: " << hipGetErrorString(errA) << std::endl;
}
errB = hipFree(d_B);
if(errB != hipSuccess)
{
std::cerr << "Error free the B memory: " << hipGetErrorString(errB) << std::endl;
}
errC = hipFree(d_C);
if(errC != hipSuccess)
{
std::cerr << "Error free the C memory: " << hipGetErrorString(errC) << std::endl;
}
return; return;
} }
} // namespace ck_tile } // namespace ck_tile
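Note: after this refactor reference_gemm_gpu and reference_batched_gemm_gpu no longer allocate, copy, or free device memory; the caller owns the buffers, as the batched example above shows. A caller-side sketch for the single-GEMM variant (packed layouts assumed; host_a/host_b/host_c are hypothetical host containers):

    ADataType* d_A; BDataType* d_B; CDataType* d_C;
    ck_tile::hip_check_error(hipMalloc(&d_A, M * K * sizeof(ADataType)));
    ck_tile::hip_check_error(hipMalloc(&d_B, N * K * sizeof(BDataType)));
    ck_tile::hip_check_error(hipMalloc(&d_C, M * N * sizeof(CDataType)));
    ck_tile::hip_check_error(hipMemcpy(d_A, host_a.data(), M * K * sizeof(ADataType), hipMemcpyHostToDevice));
    ck_tile::hip_check_error(hipMemcpy(d_B, host_b.data(), N * K * sizeof(BDataType), hipMemcpyHostToDevice));
    ck_tile::reference_gemm_gpu<ADataType, BDataType, AccDataType, CDataType, LayoutA, LayoutB, LayoutC>(
        d_A, d_B, d_C, M, N, K, stride_a, stride_b, stride_c);
    ck_tile::hip_check_error(hipMemcpy(host_c.data(), d_C, M * N * sizeof(CDataType), hipMemcpyDeviceToHost));
    ck_tile::hip_check_error(hipFree(d_A));
    ck_tile::hip_check_error(hipFree(d_B));
    ck_tile::hip_check_error(hipFree(d_C));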
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
#include "ck_tile/ops/flatmm/block/flatmm_32x512x128_1x4x1_16x16x32.hpp" #include "ck_tile/ops/flatmm/block/flatmm_32x512x128_1x4x1_16x16x32.hpp"
#include "ck_tile/ops/flatmm/block/flatmm_sn_32x128x512_1x4x1_16x16x32.hpp" #include "ck_tile/ops/flatmm/block/flatmm_sn_32x128x512_1x4x1_16x16x32.hpp"
#include "ck_tile/ops/flatmm/block/flatmm_sn_32x128x512_1x4x1_16x16x32_itl.hpp"
#include "ck_tile/ops/flatmm/block/flatmm_uk_config.hpp" #include "ck_tile/ops/flatmm/block/flatmm_uk_config.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/common/tensor_layout.hpp"