Commit beff3933 authored by Jing Zhang
Browse files

Implemented reduce function in GEMM DeviceOp

parent c9d9e24d
......@@ -489,6 +489,7 @@ struct DeviceGemmMultipleDScaleAB_Xdl_CShuffle
void* p_output,
const StreamConfig& stream_config = StreamConfig{})
{
// Reduce larger dim first to reduce workspace size
const std::array<int, 1> reduceDims_1 = {arrInLengths_1[0] > arrInLengths_1[1] ? 0 : 1};
const std::array<int, 1> reduceDims_2 = {0};
......@@ -578,9 +579,11 @@ struct DeviceGemmMultipleDScaleAB_Xdl_CShuffle
float kern_time = 0;
ADataType amax_a, amax_b;
auto reduce_a = Reduce2D<ADataType>{};
kern_time += reduce_a.Run({arg.MRaw_, arg.KRaw_},
is_same<RowMajor, ALayout>::value
is_same<RowMajor, ALayout>::value // A[M, K]
? std::array<index_t, 2>{arg.KRaw_, I1}
: std::array<index_t, 2>{I1, arg.MRaw_},
arg.p_a_grid_,
......@@ -588,9 +591,15 @@ struct DeviceGemmMultipleDScaleAB_Xdl_CShuffle
arg.p_e_grid_,
stream_config);
hipGetErrorString(hipMemcpyWithStream(&amax_a,
arg.p_e_grid_,
sizeof(ADataType),
hipMemcpyDeviceToHost,
stream_config.stream_id_));
auto reduce_b = Reduce2D<BDataType>{};
kern_time += reduce_b.Run({arg.KRaw_, arg.NRaw_},
is_same<RowMajor, BLayout>::value
is_same<RowMajor, BLayout>::value // B[K, N]
? std::array<index_t, 2>{arg.NRaw_, I1}
: std::array<index_t, 2>{I1, arg.KRaw_},
arg.p_a_grid_,
......@@ -598,6 +607,14 @@ struct DeviceGemmMultipleDScaleAB_Xdl_CShuffle
arg.p_e_grid_,
stream_config);
hipGetErrorString(hipMemcpyWithStream(&amax_b,
arg.p_e_grid_,
sizeof(ADataType),
hipMemcpyDeviceToHost,
stream_config.stream_id_));
// std::cout << "amax_a: " << amax_a << " amax_b: " << amax_b << std::endl;
const index_t grid_size =
arg.block_2_etile_map_.CalculateGridSize(arg.e_grid_desc_m_n_);
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment