Unverified commit 0ffe956a authored by rocking5566, committed by GitHub

Gemm reduce max (#209)



* [What] Rename the example
[Why] Prepare to add unary reduction

* Add global operation to the parameter

* Add atomicMax

* Fix compile error

* Support atomicMax (hip library)

* Rename the reduction example

* Fix target name

* use p_d1_grid as the indicator directly

* Prevent a performance issue; let PassThrough handle it.

* Implement the function template that specializes float2

* No need to separate into two lines

* Remove empty line

* add comment

* Fix compile error due to merge from develop

* make the implementation of atomic_max / atomic_add explicit for each datatype

* Fix typo

* For future CI test

* Fix compiler error in ckProfiler

* Merge commit 'de2769e3a6695b38a20529261273ddc5cdaab2fe'

* simply use remove_pointer

* Rename type and var

* Refine example

* Modify reducemax example

* Fix bug in reduction

* Change initialize range

* Implement F64 version of atomicMax

* Move reduction code together

* Add buffer atomic_max

* Fix coding style by clang-format

* Integrate new api of DeviceGemmReduce_Xdl_CShuffle

* Integrate Batch gemm reduction

* Fix example

* fix example

* clean up

* Fix batch gemm tensor operation

* Fix coding style

* Fix template argument

* Fix clang format

* Keep the stride of each D tensor flexible

* Fix compile error for ckProfiler

* Fix typo

* [What] Fix naming
[Why] Prepare to add out elementop

* Add DoutElementOp
Co-authored-by: default avatarChao Liu <chao.liu2@amd.com>
Co-authored-by: default avatarrocking <chunylai@amd.com>
parent aafc3ac2
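Several commits above add explicit atomicMax support for float and double through the HIP library. That code is not shown in this diff, but the usual HIP pattern, when the hardware exposes only integer atomics, is a compare-and-swap loop over the value's bit pattern. A minimal sketch with hypothetical helper names, not the PR's actual implementation:

#include <hip/hip_runtime.h>

// Sketch of a CAS-based atomic max for float: reinterpret the bits as an
// unsigned int so the integer atomicCAS can be used, while the ordering
// comparison itself is done on the float values.
__device__ inline float atomic_max_f32(float* addr, float value)
{
    unsigned int* addr_as_u32 = reinterpret_cast<unsigned int*>(addr);
    unsigned int old          = *addr_as_u32;
    unsigned int assumed;
    do
    {
        assumed = old;
        if(__uint_as_float(assumed) >= value)
            break; // stored value already >= value; nothing to write
        old = atomicCAS(addr_as_u32, assumed, __float_as_uint(value));
    } while(assumed != old);
    return __uint_as_float(old);
}

// The F64 variant mentioned above would follow the same pattern with
// unsigned long long and __double_as_longlong / __longlong_as_double.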
@@ -17,11 +17,21 @@ namespace tensor_operation {
 namespace device {
 namespace device_gemm_instance {

+using F32 = float;
+using F16 = ck::half_t;
+using DPtrsGlobal = ck::Tuple<F32*, F32*>;
+using Identity = ck::tensor_operation::element_wise::UnaryIdentic<F32, F32, false>;
+using Square = ck::tensor_operation::element_wise::UnarySquare<F32, F32, false>;
+using DInElementOps = ck::Tuple<Identity, Square>;
+using DOutElementOps = ck::Tuple<Identity, Identity>;
+
 using DeviceGemmReduceNoOpPtr = ck::tensor_operation::device::DeviceGemmReducePtr<
+    DPtrsGlobal,
     ck::tensor_operation::element_wise::PassThrough,
     ck::tensor_operation::element_wise::PassThrough,
     ck::tensor_operation::element_wise::PassThrough,
-    ck::tensor_operation::element_wise::UnarySquare<float, float, false>>;
+    DInElementOps,
+    DOutElementOps>;

 void add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gkn_gmn_instances(
     std::vector<DeviceGemmReduceNoOpPtr>&);
@@ -119,19 +129,25 @@ bool profile_batched_gemm_reduce_impl(int do_verification,
         b_g_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5}, num_thread);
     }

     using AElementOp = ck::tensor_operation::element_wise::PassThrough;
     using BElementOp = ck::tensor_operation::element_wise::PassThrough;
     using CElementOp = ck::tensor_operation::element_wise::PassThrough;
     using D0ReduceOp = ck::reduce::Add<float>;
     using D1ReduceOp = ck::reduce::Add<float>;
-    using D1ElementOp = ck::tensor_operation::element_wise::UnarySquare<float, float, false>;
+    using UnaryIdenticElementOp =
+        ck::tensor_operation::element_wise::UnaryIdentic<float, float, false>;
+    using UnarySquareElementOp =
+        ck::tensor_operation::element_wise::UnarySquare<float, float, false>;
+    using DxsInElementOps  = ck::Tuple<UnaryIdenticElementOp, UnarySquareElementOp>;
+    using DxsOutElementOps = ck::Tuple<UnaryIdenticElementOp, UnaryIdenticElementOp>;

-    const auto a_element_op  = AElementOp{};
-    const auto b_element_op  = BElementOp{};
-    const auto c_element_op  = CElementOp{};
-    const auto d0_reduce_op  = D0ReduceOp{};
-    const auto d1_reduce_op  = D1ReduceOp{};
-    const auto d1_element_op = D1ElementOp{};
+    const auto a_element_op       = AElementOp{};
+    const auto b_element_op       = BElementOp{};
+    const auto c_element_op       = CElementOp{};
+    const auto dxs_in_element_op  = DxsInElementOps{};
+    const auto dxs_out_element_op = DxsOutElementOps{};
+    const auto d0_reduce_op       = D0ReduceOp{};
+    const auto d1_reduce_op       = D1ReduceOp{};

     if(do_verification)
     {
@@ -163,7 +179,7 @@ bool profile_batched_gemm_reduce_impl(int do_verification,
                 float d0_val = ck::type_convert<float>(c_g_m_n_host_result(batch, m, n));
                 float d1_val;

-                d1_element_op(d1_val, d0_val);
+                UnarySquareElementOp{}(d1_val, d0_val);
                 d0_reduce_op(d0_acc, d0_val);
                 d1_reduce_op(d1_acc, d1_val);
             }
@@ -180,6 +196,9 @@ bool profile_batched_gemm_reduce_impl(int do_verification,
     DeviceMem d0_device_buf(sizeof(DDataType) * d0_g_m_device_result.mDesc.GetElementSpace());
     DeviceMem d1_device_buf(sizeof(DDataType) * d1_g_m_device_result.mDesc.GetElementSpace());

+    auto dxs_global = ck::make_tuple(static_cast<DDataType*>(d0_device_buf.GetDeviceBuffer()),
+                                     static_cast<DDataType*>(d1_device_buf.GetDeviceBuffer()));
+
     a_device_buf.ToDevice(a_g_m_k.mData.data());
     b_device_buf.ToDevice(b_g_k_n.mData.data());
@@ -241,8 +260,7 @@ bool profile_batched_gemm_reduce_impl(int do_verification,
         gemm_ptr->MakeArgumentPointer(static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()),
                                       static_cast<BDataType*>(b_device_buf.GetDeviceBuffer()),
                                       static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()),
-                                      static_cast<DDataType*>(d0_device_buf.GetDeviceBuffer()),
-                                      static_cast<DDataType*>(d1_device_buf.GetDeviceBuffer()),
+                                      dxs_global,
                                       M,
                                       N,
                                       K,
@@ -252,7 +270,8 @@ bool profile_batched_gemm_reduce_impl(int do_verification,
                                       a_element_op,
                                       b_element_op,
                                       c_element_op,
-                                      d1_element_op,
+                                      dxs_in_element_op,
+                                      dxs_out_element_op,
                                       BatchCount);

     auto invoker_ptr = gemm_ptr->MakeInvokerPointer();
...
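As the verification hunk above shows, the two element-op tuples pair a per-element transform with each reduction: D0 accumulates identity(c) and D1 accumulates square(c), i.e. the per-row sum and sum of squares (the usual ingredients for a mean/variance computation). A plain-C++ restatement of that host reference, using a hypothetical helper name and assuming both reduce ops are Add, as they are in this diff:

#include <vector>

// Host reference for one row of C: the same math the verification loop
// performs with (Identity, Square) in-ops, Add reduce ops, and
// (Identity, Identity) out-ops.
void reduce_row(const std::vector<float>& c_row, float& d0, float& d1)
{
    d0 = 0.f; // reduce(identity(c)) -> row sum
    d1 = 0.f; // reduce(square(c))   -> row sum of squares
    for(float c : c_row)
    {
        d0 += c;     // Identity in-op, then Add
        d1 += c * c; // Square in-op, then Add
    }
    // DxsOutElementOps is (Identity, Identity), so no post-transform is applied.
}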
@@ -16,11 +16,21 @@ namespace tensor_operation {
 namespace device {
 namespace device_gemm_instance {

+using F32 = float;
+using F16 = ck::half_t;
+using DPtrsGlobal = ck::Tuple<F32*, F32*>;
+using Identity = ck::tensor_operation::element_wise::UnaryIdentic<F32, F32, false>;
+using Square = ck::tensor_operation::element_wise::UnarySquare<F32, F32, false>;
+using DInElementOps = ck::Tuple<Identity, Square>;
+using DOutElementOps = ck::Tuple<Identity, Identity>;
+
 using DeviceGemmReduceNoOpPtr = ck::tensor_operation::device::DeviceGemmReducePtr<
+    DPtrsGlobal,
     ck::tensor_operation::element_wise::PassThrough,
     ck::tensor_operation::element_wise::PassThrough,
     ck::tensor_operation::element_wise::PassThrough,
-    ck::tensor_operation::element_wise::UnarySquare<float, float, false>>;
+    DInElementOps,
+    DOutElementOps>;

 void add_device_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_mk_kn_mn_instances(
     std::vector<DeviceGemmReduceNoOpPtr>&);
@@ -112,19 +122,25 @@ bool profile_gemm_reduce_impl(int do_verification,
         b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5}, num_thread);
     }

     using AElementOp = ck::tensor_operation::element_wise::PassThrough;
     using BElementOp = ck::tensor_operation::element_wise::PassThrough;
     using CElementOp = ck::tensor_operation::element_wise::PassThrough;
     using D0ReduceOp = ck::reduce::Add<float>;
     using D1ReduceOp = ck::reduce::Add<float>;
-    using D1ElementOp = ck::tensor_operation::element_wise::UnarySquare<float, float, false>;
+    using UnaryIdenticElementOp =
+        ck::tensor_operation::element_wise::UnaryIdentic<float, float, false>;
+    using UnarySquareElementOp =
+        ck::tensor_operation::element_wise::UnarySquare<float, float, false>;
+    using DxsInElementOps  = ck::Tuple<UnaryIdenticElementOp, UnarySquareElementOp>;
+    using DxsOutElementOps = ck::Tuple<UnaryIdenticElementOp, UnaryIdenticElementOp>;

-    const auto a_element_op  = AElementOp{};
-    const auto b_element_op  = BElementOp{};
-    const auto c_element_op  = CElementOp{};
-    const auto d0_reduce_op  = D0ReduceOp{};
-    const auto d1_reduce_op  = D1ReduceOp{};
-    const auto d1_element_op = D1ElementOp{};
+    const auto a_element_op       = AElementOp{};
+    const auto b_element_op       = BElementOp{};
+    const auto c_element_op       = CElementOp{};
+    const auto dxs_in_element_op  = DxsInElementOps{};
+    const auto dxs_out_element_op = DxsOutElementOps{};
+    const auto d0_reduce_op       = D0ReduceOp{};
+    const auto d1_reduce_op       = D1ReduceOp{};

     if(do_verification)
     {
@@ -149,7 +165,7 @@ bool profile_gemm_reduce_impl(int do_verification,
             float d0_val = ck::type_convert<float>(c_m_n_host_result(m, n));
             float d1_val;

-            d1_element_op(d1_val, d0_val);
+            UnarySquareElementOp{}(d1_val, d0_val);
             d0_reduce_op(d0_acc, d0_val);
             d1_reduce_op(d1_acc, d1_val);
         }
@@ -165,6 +181,9 @@ bool profile_gemm_reduce_impl(int do_verification,
     DeviceMem d0_device_buf(sizeof(DDataType) * d0_m_device_result.mDesc.GetElementSpace());
     DeviceMem d1_device_buf(sizeof(DDataType) * d1_m_device_result.mDesc.GetElementSpace());

+    auto dxs_global = ck::make_tuple(static_cast<DDataType*>(d0_device_buf.GetDeviceBuffer()),
+                                     static_cast<DDataType*>(d1_device_buf.GetDeviceBuffer()));
+
     a_device_buf.ToDevice(a_m_k.mData.data());
     b_device_buf.ToDevice(b_k_n.mData.data());
@@ -226,8 +245,7 @@ bool profile_gemm_reduce_impl(int do_verification,
         gemm_ptr->MakeArgumentPointer(static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()),
                                       static_cast<BDataType*>(b_device_buf.GetDeviceBuffer()),
                                       static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()),
-                                      static_cast<DDataType*>(d0_device_buf.GetDeviceBuffer()),
-                                      static_cast<DDataType*>(d1_device_buf.GetDeviceBuffer()),
+                                      dxs_global,
                                       M,
                                       N,
                                       K,
@@ -237,7 +255,8 @@ bool profile_gemm_reduce_impl(int do_verification,
                                       a_element_op,
                                       b_element_op,
                                       c_element_op,
-                                      d1_element_op);
+                                      dxs_in_element_op,
+                                      dxs_out_element_op);

     auto invoker_ptr = gemm_ptr->MakeInvokerPointer();
...
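The point of the refactor, visible in both files, is that the reduction fan-out is no longer hard-wired to two raw D pointers: the pointers and their element ops now travel in ck::Tuples. A purely hypothetical sketch of what a third reduction output could look like under this API; the aliases mirror those introduced above, and actual device instances would also need to match the wider tuples:

// Hypothetical three-output configuration; not part of this PR.
using F32      = float;
using Identity = ck::tensor_operation::element_wise::UnaryIdentic<F32, F32, false>;
using Square   = ck::tensor_operation::element_wise::UnarySquare<F32, F32, false>;

using DPtrsGlobal    = ck::Tuple<F32*, F32*, F32*>;             // three D pointers
using DInElementOps  = ck::Tuple<Identity, Square, Identity>;   // pre-reduce transforms
using DOutElementOps = ck::Tuple<Identity, Identity, Identity>; // post-reduce transforms

// At argument-construction time, one more device pointer is packed:
// auto dxs_global = ck::make_tuple(p_d0, p_d1, p_d2);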