"docs/source/en/vscode:/vscode.git/clone" did not exist on "1f02087607aa70948a2546206c58804b59381a6f"
Commit beff3933 authored by Jing Zhang
Browse files

Implemented reduce function in Gemm DeviceOp

parent c9d9e24d
...@@ -489,6 +489,7 @@ struct DeviceGemmMultipleDScaleAB_Xdl_CShuffle ...@@ -489,6 +489,7 @@ struct DeviceGemmMultipleDScaleAB_Xdl_CShuffle
void* p_output, void* p_output,
const StreamConfig& stream_config = StreamConfig{}) const StreamConfig& stream_config = StreamConfig{})
{ {
// Reduce larger dim first to reduce workspace size
const std::array<int, 1> reduceDims_1 = {arrInLengths_1[0] > arrInLengths_1[1] ? 0 : 1}; const std::array<int, 1> reduceDims_1 = {arrInLengths_1[0] > arrInLengths_1[1] ? 0 : 1};
const std::array<int, 1> reduceDims_2 = {0}; const std::array<int, 1> reduceDims_2 = {0};
...@@ -578,9 +579,11 @@ struct DeviceGemmMultipleDScaleAB_Xdl_CShuffle ...@@ -578,9 +579,11 @@ struct DeviceGemmMultipleDScaleAB_Xdl_CShuffle
float kern_time = 0; float kern_time = 0;
ADataType amax_a, amax_b;
auto reduce_a = Reduce2D<ADataType>{}; auto reduce_a = Reduce2D<ADataType>{};
kern_time += reduce_a.Run({arg.MRaw_, arg.KRaw_}, kern_time += reduce_a.Run({arg.MRaw_, arg.KRaw_},
is_same<RowMajor, ALayout>::value is_same<RowMajor, ALayout>::value // A[M, K]
? std::array<index_t, 2>{arg.KRaw_, I1} ? std::array<index_t, 2>{arg.KRaw_, I1}
: std::array<index_t, 2>{I1, arg.MRaw_}, : std::array<index_t, 2>{I1, arg.MRaw_},
arg.p_a_grid_, arg.p_a_grid_,
...@@ -588,9 +591,15 @@ struct DeviceGemmMultipleDScaleAB_Xdl_CShuffle ...@@ -588,9 +591,15 @@ struct DeviceGemmMultipleDScaleAB_Xdl_CShuffle
arg.p_e_grid_, arg.p_e_grid_,
stream_config); stream_config);
hipGetErrorString(hipMemcpyWithStream(&amax_a,
arg.p_e_grid_,
sizeof(ADataType),
hipMemcpyDeviceToHost,
stream_config.stream_id_));
auto reduce_b = Reduce2D<BDataType>{}; auto reduce_b = Reduce2D<BDataType>{};
kern_time += reduce_b.Run({arg.KRaw_, arg.NRaw_}, kern_time += reduce_b.Run({arg.KRaw_, arg.NRaw_},
is_same<RowMajor, BLayout>::value is_same<RowMajor, BLayout>::value // B[K, N]
? std::array<index_t, 2>{arg.NRaw_, I1} ? std::array<index_t, 2>{arg.NRaw_, I1}
: std::array<index_t, 2>{I1, arg.KRaw_}, : std::array<index_t, 2>{I1, arg.KRaw_},
arg.p_a_grid_, arg.p_a_grid_,
...@@ -598,6 +607,14 @@ struct DeviceGemmMultipleDScaleAB_Xdl_CShuffle ...@@ -598,6 +607,14 @@ struct DeviceGemmMultipleDScaleAB_Xdl_CShuffle
arg.p_e_grid_, arg.p_e_grid_,
stream_config); stream_config);
hipGetErrorString(hipMemcpyWithStream(&amax_b,
arg.p_e_grid_,
sizeof(ADataType),
hipMemcpyDeviceToHost,
stream_config.stream_id_));
// std::cout << "amax_a: " << amax_a << " amax_b: " << amax_b << std::endl;
const index_t grid_size = const index_t grid_size =
arg.block_2_etile_map_.CalculateGridSize(arg.e_grid_desc_m_n_); arg.block_2_etile_map_.CalculateGridSize(arg.e_grid_desc_m_n_);
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment