Commit 6ef4e211 authored by Chao Liu's avatar Chao Liu

Merge remote-tracking branch 'origin/develop' into contraction

parents b0a2afb9 9e4429f9
@@ -6,7 +6,7 @@
 namespace ck {
 namespace tensor_operation {
 namespace device {
-namespace device_reduce_instance {
+namespace instance {
 // clang-format off
 // InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
@@ -20,7 +20,7 @@ ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(double, double, double, 5, 0, 0, 4, 1);
 ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(double, double, double, 5, 0, 0, 2, 1);
 // clang-format on
-} // namespace device_reduce_instance
+} // namespace instance
 } // namespace device
 } // namespace tensor_operation
...
@@ -6,7 +6,7 @@
 namespace ck {
 namespace tensor_operation {
 namespace device {
-namespace device_reduce_instance {
+namespace instance {
 // clang-format off
 // InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
@@ -49,7 +49,7 @@ ADD_THREADWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 4, 0, 1, 4, 1);
 ADD_THREADWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 4, 0, 1, 2, 1);
 // clang-format on
-} // namespace device_reduce_instance
+} // namespace instance
 } // namespace device
 } // namespace tensor_operation
...
@@ -6,7 +6,7 @@
 namespace ck {
 namespace tensor_operation {
 namespace device {
-namespace device_reduce_instance {
+namespace instance {
 // clang-format off
 // InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
@@ -36,7 +36,7 @@ ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 4, 0, 1, 4, 1);
 ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 4, 0, 1, 2, 1);
 // clang-format on
-} // namespace device_reduce_instance
+} // namespace instance
 } // namespace device
 } // namespace tensor_operation
...
@@ -6,7 +6,7 @@
 namespace ck {
 namespace tensor_operation {
 namespace device {
-namespace device_reduce_instance {
+namespace instance {
 // clang-format off
 // InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
@@ -24,7 +24,7 @@ ADD_THREADWISE_INST_BY_ID(half_t, float, half_t, 7, 0, 0, 4, 1);
 ADD_THREADWISE_INST_BY_ID(half_t, float, half_t, 7, 0, 0, 2, 1);
 // clang-format on
-} // namespace device_reduce_instance
+} // namespace instance
 } // namespace device
 } // namespace tensor_operation
 } // namespace ck
@@ -6,7 +6,7 @@
 namespace ck {
 namespace tensor_operation {
 namespace device {
-namespace device_reduce_instance {
+namespace instance {
 // clang-format off
 // InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
@@ -48,7 +48,7 @@ ADD_THREADWISE_INST_BY_ID(float, float, float, 4, 0, 1, 4, 1);
 ADD_THREADWISE_INST_BY_ID(float, float, float, 4, 0, 1, 2, 1);
 // clang-format on
-} // namespace device_reduce_instance
+} // namespace instance
 } // namespace device
 } // namespace tensor_operation
...
@@ -6,7 +6,7 @@
 namespace ck {
 namespace tensor_operation {
 namespace device {
-namespace device_reduce_instance {
+namespace instance {
 // clang-format off
 // InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
@@ -24,7 +24,7 @@ ADD_THREADWISE_INST_BY_ID(float, double, float, 7, 0, 0, 4, 1);
 ADD_THREADWISE_INST_BY_ID(float, double, float, 7, 0, 0, 2, 1);
 // clang-format on
-} // namespace device_reduce_instance
+} // namespace instance
 } // namespace device
 } // namespace tensor_operation
...
@@ -6,7 +6,7 @@
 namespace ck {
 namespace tensor_operation {
 namespace device {
-namespace device_reduce_instance {
+namespace instance {
 // clang-format off
 // InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
@@ -48,7 +48,7 @@ ADD_THREADWISE_INST_BY_ID(double, double, double, 4, 0, 1, 4, 1);
 ADD_THREADWISE_INST_BY_ID(double, double, double, 4, 0, 1, 2, 1);
 // clang-format on
-} // namespace device_reduce_instance
+} // namespace instance
 } // namespace device
 } // namespace tensor_operation
 } // namespace ck
@@ -6,7 +6,7 @@
 namespace ck {
 namespace tensor_operation {
 namespace device {
-namespace device_reduce_instance {
+namespace instance {
 // clang-format off
 // InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
@@ -21,7 +21,7 @@ ADD_THREADWISE_INST_BY_ID(int8_t, int32_t, int8_t, 5, 0, 0, 2, 1);
 // clang-format on
 // clang-format on
-} // namespace device_reduce_instance
+} // namespace instance
 } // namespace device
 } // namespace tensor_operation
...
@@ -6,7 +6,7 @@
 namespace ck {
 namespace tensor_operation {
 namespace device {
-namespace device_reduce_instance {
+namespace instance {
 // clang-format off
 // InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
@@ -36,7 +36,7 @@ ADD_THREADWISE_INST_BY_ID(int8_t, int8_t, int8_t, 4, 0, 1, 4, 1);
 ADD_THREADWISE_INST_BY_ID(int8_t, int8_t, int8_t, 4, 0, 1, 2, 1);
 // clang-format on
-} // namespace device_reduce_instance
+} // namespace instance
 } // namespace device
 } // namespace tensor_operation
...
@@ -6,42 +6,42 @@ include_directories(BEFORE
 set(PROFILER_SOURCE
     src/profiler.cpp
     src/profile_gemm.cpp
-    src/profile_gemm_bias_2d.cpp
-    src/profile_gemm_bias_relu.cpp
-    src/profile_gemm_bias_relu_add.cpp
-    src/profile_gemm_reduce.cpp
+    src/profile_gemm_splitk.cpp
+    src/profile_gemm_bilinear.cpp
     src/profile_gemm_bias_add_reduce.cpp
+    src/profile_gemm_add_add_fastgelu.cpp
+    src/profile_gemm_reduce.cpp
     src/profile_batched_gemm.cpp
+    src/profile_batched_gemm_reduce.cpp
+    src/profile_grouped_gemm.cpp
     src/profile_conv_fwd_bias_relu.cpp
     src/profile_conv_fwd_bias_relu_add.cpp
     src/profile_convnd_fwd.cpp
     src/profile_convnd_bwd_data.cpp
-    src/profile_reduce.cpp
-    src/profile_grouped_gemm.cpp
     src/profile_conv_bwd_weight.cpp
-    src/profile_batched_gemm_reduce.cpp
-    src/profile_gemm_add_add_fastgelu.cpp
+    src/profile_reduce.cpp
+    src/profile_normalization.cpp
 )
 add_executable(ckProfiler ${PROFILER_SOURCE})
 target_link_libraries(ckProfiler PRIVATE host_tensor)
 target_link_libraries(ckProfiler PRIVATE conv_util)
+target_link_libraries(ckProfiler PRIVATE device_gemm_instance)
+target_link_libraries(ckProfiler PRIVATE device_gemm_splitk_instance)
+target_link_libraries(ckProfiler PRIVATE device_gemm_bilinear_instance)
+target_link_libraries(ckProfiler PRIVATE device_gemm_add_add_fastgelu_instance)
 target_link_libraries(ckProfiler PRIVATE device_gemm_reduce_instance)
 target_link_libraries(ckProfiler PRIVATE device_gemm_bias_add_reduce_instance)
-target_link_libraries(ckProfiler PRIVATE device_gemm_instance)
-target_link_libraries(ckProfiler PRIVATE device_gemm_bias2d_instance)
-target_link_libraries(ckProfiler PRIVATE device_gemm_bias_relu_instance)
-target_link_libraries(ckProfiler PRIVATE device_gemm_bias_relu_add_instance)
 target_link_libraries(ckProfiler PRIVATE device_batched_gemm_instance)
+target_link_libraries(ckProfiler PRIVATE device_batched_gemm_reduce_instance)
+target_link_libraries(ckProfiler PRIVATE device_grouped_gemm_instance)
 target_link_libraries(ckProfiler PRIVATE device_conv1d_fwd_instance)
 target_link_libraries(ckProfiler PRIVATE device_conv2d_fwd_instance)
 target_link_libraries(ckProfiler PRIVATE device_conv3d_fwd_instance)
 target_link_libraries(ckProfiler PRIVATE device_conv2d_fwd_bias_relu_instance)
 target_link_libraries(ckProfiler PRIVATE device_conv2d_fwd_bias_relu_add_instance)
 target_link_libraries(ckProfiler PRIVATE device_convnd_bwd_data_instance)
-target_link_libraries(ckProfiler PRIVATE device_reduce_instance)
-target_link_libraries(ckProfiler PRIVATE device_grouped_gemm_instance)
 target_link_libraries(ckProfiler PRIVATE device_conv2d_bwd_weight_instance)
-target_link_libraries(ckProfiler PRIVATE device_batched_gemm_reduce_instance)
-target_link_libraries(ckProfiler PRIVATE device_gemm_add_add_fastgelu_instance)
+target_link_libraries(ckProfiler PRIVATE device_normalization_instance)
+target_link_libraries(ckProfiler PRIVATE device_reduce_instance)
...
@@ -7,56 +7,17 @@
 #include "ck/ck.hpp"
 #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
-#include "ck/tensor_operation/gpu/device/device_gemm.hpp"
+#include "ck/tensor_operation/gpu/device/device_batched_gemm.hpp"
 #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
+#include "ck/library/tensor_operation_instance/gpu/batched_gemm.hpp"
 #include "ck/library/utility/check_err.hpp"
-#include "ck/library/utility/conv_util.hpp"
 #include "ck/library/host_tensor/device_memory.hpp"
 #include "ck/library/host_tensor/host_tensor.hpp"
 #include "ck/library/host_tensor/host_tensor_generator.hpp"
 #include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
-namespace ck {
-namespace tensor_operation {
-namespace device {
-namespace device_batched_gemm_instance {
-using DeviceGemmNoOpPtr =
-    ck::tensor_operation::device::DeviceGemmPtr<ck::tensor_operation::element_wise::PassThrough,
-                                                ck::tensor_operation::element_wise::PassThrough,
-                                                ck::tensor_operation::element_wise::PassThrough>;
-void add_device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gkn_gmn_instances(
-    std::vector<DeviceGemmNoOpPtr>&);
-void add_device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gnk_gmn_instances(
-    std::vector<DeviceGemmNoOpPtr>&);
-void add_device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gkn_gmn_instances(
-    std::vector<DeviceGemmNoOpPtr>&);
-void add_device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gnk_gmn_instances(
-    std::vector<DeviceGemmNoOpPtr>&);
-void add_device_batched_gemm_xdl_f16_f16_f16_gmk_gkn_gmn_instances(std::vector<DeviceGemmNoOpPtr>&);
-void add_device_batched_gemm_xdl_f16_f16_f16_gmk_gnk_gmn_instances(std::vector<DeviceGemmNoOpPtr>&);
-void add_device_batched_gemm_xdl_f16_f16_f16_gkm_gkn_gmn_instances(std::vector<DeviceGemmNoOpPtr>&);
-void add_device_batched_gemm_xdl_f16_f16_f16_gkm_gnk_gmn_instances(std::vector<DeviceGemmNoOpPtr>&);
-void add_device_batched_gemm_xdl_f32_f32_f32_gmk_gkn_gmn_instances(std::vector<DeviceGemmNoOpPtr>&);
-void add_device_batched_gemm_xdl_f32_f32_f32_gmk_gnk_gmn_instances(std::vector<DeviceGemmNoOpPtr>&);
-void add_device_batched_gemm_xdl_f32_f32_f32_gkm_gkn_gmn_instances(std::vector<DeviceGemmNoOpPtr>&);
-void add_device_batched_gemm_xdl_f32_f32_f32_gkm_gnk_gmn_instances(std::vector<DeviceGemmNoOpPtr>&);
-void add_device_batched_gemm_xdl_int8_int8_int8_gmk_gkn_gmn_instances(
-    std::vector<DeviceGemmNoOpPtr>&);
-void add_device_batched_gemm_xdl_int8_int8_int8_gmk_gnk_gmn_instances(
-    std::vector<DeviceGemmNoOpPtr>&);
-void add_device_batched_gemm_xdl_int8_int8_int8_gkm_gkn_gmn_instances(
-    std::vector<DeviceGemmNoOpPtr>&);
-void add_device_batched_gemm_xdl_int8_int8_int8_gkm_gnk_gmn_instances(
-    std::vector<DeviceGemmNoOpPtr>&);
-} // namespace device_batched_gemm_instance
-} // namespace device
-} // namespace tensor_operation
-} // namespace ck
 namespace ck {
 namespace profiler {
@@ -73,6 +34,9 @@ bool profile_batched_gemm_impl(int do_verification,
                                int M,
                                int N,
                                int K,
+                               int BatchStrideA,
+                               int BatchStrideB,
+                               int BatchStrideC,
                                int StrideA,
                                int StrideB,
                                int StrideC,
@@ -84,46 +48,44 @@ bool profile_batched_gemm_impl(int do_verification,
                                   std::size_t row,
                                   std::size_t col,
                                   std::size_t stride,
+                                  std::size_t batch_stride,
                                   auto layout) {
         if(is_same<decltype(layout), tensor_layout::gemm::RowMajor>::value)
         {
             return HostTensorDescriptor(std::vector<std::size_t>({batch_count, row, col}),
-                                        std::vector<std::size_t>({row * stride, stride, 1}));
+                                        std::vector<std::size_t>({batch_stride, stride, 1}));
         }
         else
         {
             return HostTensorDescriptor(std::vector<std::size_t>({batch_count, row, col}),
-                                        std::vector<std::size_t>({col * stride, 1, stride}));
+                                        std::vector<std::size_t>({batch_stride, 1, stride}));
         }
     };
-    Tensor<ADataType> a_g_m_k(f_host_tensor_descriptor(BatchCount, M, K, StrideA, ALayout{}));
-    Tensor<BDataType> b_g_k_n(f_host_tensor_descriptor(BatchCount, K, N, StrideB, BLayout{}));
+    Tensor<ADataType> a_g_m_k(
+        f_host_tensor_descriptor(BatchCount, M, K, StrideA, BatchStrideA, ALayout{}));
+    Tensor<BDataType> b_g_k_n(
+        f_host_tensor_descriptor(BatchCount, K, N, StrideB, BatchStrideB, BLayout{}));
     Tensor<CDataType> c_g_m_n_host_result(
-        f_host_tensor_descriptor(BatchCount, M, N, StrideC, CLayout{}));
+        f_host_tensor_descriptor(BatchCount, M, N, StrideC, BatchStrideC, CLayout{}));
     Tensor<CDataType> c_g_m_n_device_result(
-        f_host_tensor_descriptor(BatchCount, M, N, StrideC, CLayout{}));
+        f_host_tensor_descriptor(BatchCount, M, N, StrideC, BatchStrideC, CLayout{}));
-    std::unique_ptr<Tensor<float>> c_f32_g_m_n_host_result   = nullptr;
-    std::unique_ptr<Tensor<float>> c_f32_g_m_n_device_result = nullptr;
     std::cout << "a_g_m_k: " << a_g_m_k.mDesc << std::endl;
     std::cout << "b_g_k_n: " << b_g_k_n.mDesc << std::endl;
     std::cout << "c_g_m_n: " << c_g_m_n_host_result.mDesc << std::endl;
-    std::size_t num_thread = 1;
     switch(init_method)
     {
     case 0: break;
     case 1:
-        a_g_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5}, num_thread);
-        b_g_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5}, num_thread);
+        a_g_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
+        b_g_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
         break;
     default:
-        a_g_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0}, num_thread);
-        b_g_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5}, num_thread);
+        a_g_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
+        b_g_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
     }
-    // set zero to c_device_buf
-    c_g_m_n_device_result.GenerateTensorValue(GeneratorTensor_0<CDataType>{}, num_thread);
     using AElementOp = ck::tensor_operation::element_wise::PassThrough;
     using BElementOp = ck::tensor_operation::element_wise::PassThrough;
@@ -135,56 +97,21 @@ bool profile_batched_gemm_impl(int do_verification,
     if(do_verification)
     {
-        if constexpr(is_same<ADataType, ck::bhalf_t>::value &&
-                     is_same<BDataType, ck::bhalf_t>::value &&
-                     is_same<CDataType, ck::bhalf_t>::value)
-        {
-            Tensor<float> a_f32_g_m_k(
-                f_host_tensor_descriptor(BatchCount, M, K, StrideA, ALayout{}));
-            Tensor<float> b_f32_g_k_n(
-                f_host_tensor_descriptor(BatchCount, K, N, StrideB, BLayout{}));
-            c_f32_g_m_n_host_result = std::make_unique<Tensor<float>>(
-                f_host_tensor_descriptor(BatchCount, M, N, StrideC, CLayout{}));
-            c_f32_g_m_n_device_result = std::make_unique<Tensor<float>>(
-                f_host_tensor_descriptor(BatchCount, M, N, StrideC, CLayout{}));
-            bf16_to_f32_(a_g_m_k, a_f32_g_m_k);
-            bf16_to_f32_(b_g_k_n, b_f32_g_k_n);
-            using ReferenceBatchedGemmInstance = ck::tensor_operation::host::
-                ReferenceBatchedGemm<float, float, float, AElementOp, BElementOp, CElementOp>;
-            auto ref_batched_gemm = ReferenceBatchedGemmInstance{};
-            auto ref_invoker       = ref_batched_gemm.MakeInvoker();
-            auto ref_argument = ref_batched_gemm.MakeArgument(a_f32_g_m_k,
-                                                              b_f32_g_k_n,
-                                                              *c_f32_g_m_n_host_result,
-                                                              a_element_op,
-                                                              b_element_op,
-                                                              c_element_op);
-            ref_invoker.Run(ref_argument);
-        }
-        else
-        {
-            using ReferenceBatchedGemmInstance =
-                ck::tensor_operation::host::ReferenceBatchedGemm<ADataType,
-                                                                 BDataType,
-                                                                 CDataType,
-                                                                 AElementOp,
-                                                                 BElementOp,
-                                                                 CElementOp>;
-            auto ref_batched_gemm = ReferenceBatchedGemmInstance{};
-            auto ref_invoker       = ref_batched_gemm.MakeInvoker();
-            auto ref_argument = ref_batched_gemm.MakeArgument(
-                a_g_m_k, b_g_k_n, c_g_m_n_host_result, a_element_op, b_element_op, c_element_op);
-            ref_invoker.Run(ref_argument);
-        }
+        using ReferenceBatchedGemmInstance =
+            ck::tensor_operation::host::ReferenceBatchedGemm<ADataType,
+                                                             BDataType,
+                                                             CDataType,
+                                                             AElementOp,
+                                                             BElementOp,
+                                                             CElementOp>;
+        auto ref_batched_gemm = ReferenceBatchedGemmInstance{};
+        auto ref_invoker      = ref_batched_gemm.MakeInvoker();
+        auto ref_argument = ref_batched_gemm.MakeArgument(
+            a_g_m_k, b_g_k_n, c_g_m_n_host_result, a_element_op, b_element_op, c_element_op);
+        ref_invoker.Run(ref_argument);
     }
     DeviceMem a_device_buf(sizeof(ADataType) * a_g_m_k.mDesc.GetElementSpace());
@@ -195,172 +122,56 @@ bool profile_batched_gemm_impl(int do_verification,
     b_device_buf.ToDevice(b_g_k_n.mData.data());
     c_device_buf.ToDevice(c_g_m_n_device_result.mData.data());
-    // add device GEMM instances
-    std::vector<ck::tensor_operation::device::device_batched_gemm_instance::DeviceGemmNoOpPtr>
-        gemm_ptrs;
+    using DeviceOp = ck::tensor_operation::device::DeviceBatchedGemm<ALayout,
+                                                                     BLayout,
+                                                                     CLayout,
+                                                                     ADataType,
+                                                                     BDataType,
+                                                                     CDataType,
+                                                                     AElementOp,
+                                                                     BElementOp,
+                                                                     CElementOp>;
+    // get device op instances
+    const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
+        DeviceOp>::GetInstances();
-    if constexpr(is_same<ADataType, half_t>::value && is_same<BDataType, half_t>::value &&
-                 is_same<CDataType, half_t>::value)
-    {
-        if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
-                     is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
-                     is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
-        {
-            ck::tensor_operation::device::device_batched_gemm_instance::
-                add_device_batched_gemm_xdl_f16_f16_f16_gmk_gkn_gmn_instances(gemm_ptrs);
-        }
-        else if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
-                          is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
-                          is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
-        {
-            ck::tensor_operation::device::device_batched_gemm_instance::
-                add_device_batched_gemm_xdl_f16_f16_f16_gmk_gnk_gmn_instances(gemm_ptrs);
-        }
-        else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
-                          is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
-                          is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
-        {
-            ck::tensor_operation::device::device_batched_gemm_instance::
-                add_device_batched_gemm_xdl_f16_f16_f16_gkm_gkn_gmn_instances(gemm_ptrs);
-        }
-        else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
-                          is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
-                          is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
-        {
-            ck::tensor_operation::device::device_batched_gemm_instance::
-                add_device_batched_gemm_xdl_f16_f16_f16_gkm_gnk_gmn_instances(gemm_ptrs);
-        }
-    }
-    else if constexpr(is_same<ADataType, bhalf_t>::value && is_same<BDataType, bhalf_t>::value &&
-                      is_same<CDataType, bhalf_t>::value)
-    {
-        if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
-                     is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
-                     is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
-        {
-            ck::tensor_operation::device::device_batched_gemm_instance::
-                add_device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gkn_gmn_instances(gemm_ptrs);
-        }
-        else if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
-                          is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
-                          is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
-        {
-            ck::tensor_operation::device::device_batched_gemm_instance::
-                add_device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gnk_gmn_instances(gemm_ptrs);
-        }
-        else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
-                          is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
-                          is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
-        {
-            ck::tensor_operation::device::device_batched_gemm_instance::
-                add_device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gkn_gmn_instances(gemm_ptrs);
-        }
-        else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
-                          is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
-                          is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
-        {
-            ck::tensor_operation::device::device_batched_gemm_instance::
-                add_device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gnk_gmn_instances(gemm_ptrs);
-        }
-    }
-    else if constexpr(is_same<ADataType, float>::value && is_same<BDataType, float>::value &&
-                      is_same<CDataType, float>::value)
-    {
-        if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
-                     is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
-                     is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
-        {
-            ck::tensor_operation::device::device_batched_gemm_instance::
-                add_device_batched_gemm_xdl_f32_f32_f32_gmk_gkn_gmn_instances(gemm_ptrs);
-        }
-        else if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
-                          is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
-                          is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
-        {
-            ck::tensor_operation::device::device_batched_gemm_instance::
-                add_device_batched_gemm_xdl_f32_f32_f32_gmk_gnk_gmn_instances(gemm_ptrs);
-        }
-        else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
-                          is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
-                          is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
-        {
-            ck::tensor_operation::device::device_batched_gemm_instance::
-                add_device_batched_gemm_xdl_f32_f32_f32_gkm_gkn_gmn_instances(gemm_ptrs);
-        }
-        else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
-                          is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
-                          is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
-        {
-            ck::tensor_operation::device::device_batched_gemm_instance::
-                add_device_batched_gemm_xdl_f32_f32_f32_gkm_gnk_gmn_instances(gemm_ptrs);
-        }
-    }
-    else if constexpr(is_same<ADataType, int8_t>::value && is_same<BDataType, int8_t>::value &&
-                      is_same<CDataType, int8_t>::value)
-    {
-        if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
-                     is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
-                     is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
-        {
-            ck::tensor_operation::device::device_batched_gemm_instance::
-                add_device_batched_gemm_xdl_int8_int8_int8_gmk_gkn_gmn_instances(gemm_ptrs);
-        }
-        else if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
-                          is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
-                          is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
-        {
-            ck::tensor_operation::device::device_batched_gemm_instance::
-                add_device_batched_gemm_xdl_int8_int8_int8_gmk_gnk_gmn_instances(gemm_ptrs);
-        }
-        else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
-                          is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
-                          is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
-        {
-            ck::tensor_operation::device::device_batched_gemm_instance::
-                add_device_batched_gemm_xdl_int8_int8_int8_gkm_gkn_gmn_instances(gemm_ptrs);
-        }
-        else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
-                          is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
-                          is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
-        {
-            ck::tensor_operation::device::device_batched_gemm_instance::
-                add_device_batched_gemm_xdl_int8_int8_int8_gkm_gnk_gmn_instances(gemm_ptrs);
-        }
-    }
-    if(gemm_ptrs.size() <= 0)
-    {
-        throw std::runtime_error("wrong! no device GEMM instance found");
-    }
+    std::cout << "found " << op_ptrs.size() << " instances" << std::endl;
-    std::string best_gemm_name;
+    std::string best_op_name;
     float best_ave_time   = 0;
     float best_tflops     = 0;
     float best_gb_per_sec = 0;
-    // profile device GEMM instances
-    for(auto& gemm_ptr : gemm_ptrs)
+    // profile device op instances
+    for(auto& op_ptr : op_ptrs)
     {
         auto argument_ptr =
-            gemm_ptr->MakeArgumentPointer(static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()),
+            op_ptr->MakeArgumentPointer(static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()),
                                           static_cast<BDataType*>(b_device_buf.GetDeviceBuffer()),
                                           static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()),
                                           M,
                                           N,
                                           K,
                                           StrideA,
                                           StrideB,
                                           StrideC,
-                                          ck::tensor_operation::element_wise::PassThrough{},
-                                          ck::tensor_operation::element_wise::PassThrough{},
-                                          ck::tensor_operation::element_wise::PassThrough{},
-                                          BatchCount);
+                                          BatchStrideA,
+                                          BatchStrideB,
+                                          BatchStrideC,
+                                          BatchCount,
+                                          ck::tensor_operation::element_wise::PassThrough{},
+                                          ck::tensor_operation::element_wise::PassThrough{},
+                                          ck::tensor_operation::element_wise::PassThrough{});
-        auto invoker_ptr = gemm_ptr->MakeInvokerPointer();
+        auto invoker_ptr = op_ptr->MakeInvokerPointer();
-        if(gemm_ptr->IsSupportedArgument(argument_ptr.get()))
+        if(op_ptr->IsSupportedArgument(argument_ptr.get()))
         {
-            std::string gemm_name = gemm_ptr->GetTypeString();
+            // re-init C to zero before profiling next kernel
+            c_device_buf.SetZero();
+            std::string op_name = op_ptr->GetTypeString();
             float ave_time =
                 invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel});
@@ -376,11 +187,11 @@ bool profile_batched_gemm_impl(int do_verification,
             float gb_per_sec = num_btype / 1.E6 / ave_time;
             std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec
-                      << " GB/s, " << gemm_name << std::endl;
+                      << " GB/s, " << op_name << std::endl;
             if(tflops > best_tflops)
             {
-                best_gemm_name  = gemm_name;
+                best_op_name    = op_name;
                 best_tflops     = tflops;
                 best_ave_time   = ave_time;
                 best_gb_per_sec = gb_per_sec;
@@ -390,20 +201,8 @@ bool profile_batched_gemm_impl(int do_verification,
             {
                 c_device_buf.FromDevice(c_g_m_n_device_result.mData.data());
-                if constexpr(is_same<ADataType, ck::bhalf_t>::value &&
-                             is_same<BDataType, ck::bhalf_t>::value &&
-                             is_same<CDataType, ck::bhalf_t>::value)
-                {
-                    bf16_to_f32_(c_g_m_n_device_result, *c_f32_g_m_n_device_result);
-                    float err = check_error(*c_f32_g_m_n_host_result, *c_f32_g_m_n_device_result);
-                    pass      = pass && (err < 1E-6);
-                }
-                else
-                {
-                    float err = check_error(c_g_m_n_host_result, c_g_m_n_device_result);
-                    pass      = pass && (err < 1E-6);
-                }
+                pass = pass &
+                       ck::utils::check_err(c_g_m_n_device_result.mData, c_g_m_n_host_result.mData);
                 if(do_log)
                 {
@@ -419,13 +218,12 @@ bool profile_batched_gemm_impl(int do_verification,
         }
         else
         {
-            std::cout << "this device GEMM instance does not support this GEMM problem"
-                      << std::endl;
+            std::cout << op_ptr->GetTypeString() << " does not support this problem" << std::endl;
         }
     }
     std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, "
-              << best_gb_per_sec << " GB/s, " << best_gemm_name << std::endl;
+              << best_gb_per_sec << " GB/s, " << best_op_name << std::endl;
     return pass;
 }
...
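For orientation, the batched-GEMM profiler above no longer hand-picks add_device_batched_gemm_xdl_* instance lists per data type and layout; it describes the problem once through a DeviceBatchedGemm alias, asks the generic instance factory for every matching kernel, and filters at run time with IsSupportedArgument. The following is a minimal sketch of that lookup pattern using only the calls visible in the new code above; the f16 row-major aliases, the helper name profile_all_f16_rrr_instances, and the untyped buffer parameters are illustrative assumptions, not part of this commit.

// Sketch only: factory-based instance lookup, mirroring the new profiler body above.
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/gpu/batched_gemm.hpp"

// Hypothetical helper; the pointers are assumed to be valid device allocations.
void profile_all_f16_rrr_instances(ck::half_t* p_a,
                                   ck::half_t* p_b,
                                   ck::half_t* p_c,
                                   int M, int N, int K,
                                   int StrideA, int StrideB, int StrideC,
                                   int BatchStrideA, int BatchStrideB, int BatchStrideC,
                                   int BatchCount)
{
    using Row         = ck::tensor_layout::gemm::RowMajor;
    using F16         = ck::half_t;
    using PassThrough = ck::tensor_operation::element_wise::PassThrough;

    // Compile-time problem description; replaces the old per-layout if-constexpr dispatch.
    using DeviceOp = ck::tensor_operation::device::
        DeviceBatchedGemm<Row, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>;

    const auto op_ptrs = ck::tensor_operation::device::instance::
        DeviceOperationInstanceFactory<DeviceOp>::GetInstances();

    for(auto& op_ptr : op_ptrs)
    {
        auto argument_ptr = op_ptr->MakeArgumentPointer(p_a, p_b, p_c,
                                                        M, N, K,
                                                        StrideA, StrideB, StrideC,
                                                        BatchStrideA, BatchStrideB, BatchStrideC,
                                                        BatchCount,
                                                        PassThrough{}, PassThrough{}, PassThrough{});

        // Run-time filtering replaces the old "no device GEMM instance found" error path.
        if(op_ptr->IsSupportedArgument(argument_ptr.get()))
            op_ptr->MakeInvokerPointer()->Run(argument_ptr.get(), StreamConfig{nullptr, true});
    }
}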
@@ -19,22 +19,18 @@
 namespace ck {
 namespace tensor_operation {
 namespace device {
-namespace device_gemm_instance {
+namespace instance {
 using F32 = float;
 using F16 = ck::half_t;
-using DPtrsGlobal = ck::Tuple<F32*, F32*>;
+using ReducePtrsGlobal = ck::Tuple<F32*, F32*>;
 using Identity = ck::tensor_operation::element_wise::PassThrough;
 using Square   = ck::tensor_operation::element_wise::UnarySquare;
-using DInElementOps  = ck::Tuple<Identity, Square>;
-using DOutElementOps = ck::Tuple<Identity, Identity>;
+using ReduceInElementOps  = ck::Tuple<Identity, Square>;
+using ReduceOutElementOps = ck::Tuple<Identity, Identity>;
-using DeviceGemmReduceNoOpPtr = ck::tensor_operation::device::DeviceGemmReducePtr<
-    ck::tensor_operation::element_wise::PassThrough,
-    ck::tensor_operation::element_wise::PassThrough,
-    ck::tensor_operation::element_wise::PassThrough,
-    DInElementOps,
-    DOutElementOps>;
+using DeviceGemmReduceNoOpPtr =
+    ck::tensor_operation::device::DeviceGemmReducePtr<0, ReducePtrsGlobal::Size()>;
 void add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gkn_gmn_instances(
     std::vector<DeviceGemmReduceNoOpPtr>&);
@@ -48,7 +44,7 @@ void add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gkn_gmn
 void add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gnk_gmn_instances(
     std::vector<DeviceGemmReduceNoOpPtr>&);
-} // namespace device_gemm_instance
+} // namespace instance
 } // namespace device
 } // namespace tensor_operation
 } // namespace ck
@@ -59,7 +55,7 @@ namespace profiler {
 template <typename ADataType,
           typename BDataType,
           typename CDataType,
-          typename DDataType,
+          typename ReduceDataType,
           typename ALayout,
           typename BLayout,
           typename CLayout>
@@ -99,16 +95,16 @@ bool profile_batched_gemm_reduce_impl(int do_verification,
     Tensor<CDataType> c_g_m_n_host_result(
         f_host_tensor_descriptor(BatchCount, M, N, StrideC, CLayout{}));
-    Tensor<DDataType> d0_g_m_host_result(HostTensorDescriptor(std::vector<std::size_t>(
+    Tensor<ReduceDataType> d0_g_m_host_result(HostTensorDescriptor(std::vector<std::size_t>(
         {static_cast<std::size_t>(BatchCount), static_cast<std::size_t>(M)})));
-    Tensor<DDataType> d1_g_m_host_result(HostTensorDescriptor(std::vector<std::size_t>(
+    Tensor<ReduceDataType> d1_g_m_host_result(HostTensorDescriptor(std::vector<std::size_t>(
         {static_cast<std::size_t>(BatchCount), static_cast<std::size_t>(M)})));
     Tensor<CDataType> c_g_m_n_device_result(
         f_host_tensor_descriptor(BatchCount, M, N, StrideC, CLayout{}));
-    Tensor<DDataType> d0_g_m_device_result(HostTensorDescriptor(std::vector<std::size_t>(
+    Tensor<ReduceDataType> d0_g_m_device_result(HostTensorDescriptor(std::vector<std::size_t>(
         {static_cast<std::size_t>(BatchCount), static_cast<std::size_t>(M)})));
-    Tensor<DDataType> d1_g_m_device_result(HostTensorDescriptor(std::vector<std::size_t>(
+    Tensor<ReduceDataType> d1_g_m_device_result(HostTensorDescriptor(std::vector<std::size_t>(
         {static_cast<std::size_t>(BatchCount), static_cast<std::size_t>(M)})));
     std::cout << "a_g_m_k: " << a_g_m_k.mDesc << std::endl;
@@ -135,20 +131,23 @@ bool profile_batched_gemm_reduce_impl(int do_verification,
     using AElementOp = ck::tensor_operation::element_wise::PassThrough;
     using BElementOp = ck::tensor_operation::element_wise::PassThrough;
     using CElementOp = ck::tensor_operation::element_wise::PassThrough;
-    using D0ReduceOp = ck::reduce::Add;
-    using D1ReduceOp = ck::reduce::Add;
+    using ReduceOp0  = ck::reduce::Add;
+    using ReduceOp1  = ck::reduce::Add;
     using UnaryIdenticElementOp = ck::tensor_operation::element_wise::PassThrough;
     using UnarySquareElementOp  = ck::tensor_operation::element_wise::UnarySquare;
-    using DxsInElementOps  = ck::Tuple<UnaryIdenticElementOp, UnarySquareElementOp>;
-    using DxsOutElementOps = ck::Tuple<UnaryIdenticElementOp, UnaryIdenticElementOp>;
-    const auto a_element_op = AElementOp{};
-    const auto b_element_op = BElementOp{};
-    const auto c_element_op = CElementOp{};
-    const auto dxs_in_element_op  = DxsInElementOps{};
-    const auto dxs_out_element_op = DxsOutElementOps{};
-    const auto d0_reduce_op = D0ReduceOp{};
-    const auto d1_reduce_op = D1ReduceOp{};
+    auto a_element_op = AElementOp{};
+    auto b_element_op = BElementOp{};
+    auto c_element_op = CElementOp{};
+    std::array<void*, 3> gemm_element_ops = {&a_element_op, &b_element_op, &c_element_op};
+    const auto reduce0_op = ReduceOp0{};
+    const auto reduce1_op = ReduceOp1{};
+    auto passthrough = UnaryIdenticElementOp{};
+    auto square      = UnarySquareElementOp{};
+    std::array<void*, 2> reduce_in_element_ops  = {&passthrough, &square};
+    std::array<void*, 2> reduce_out_element_ops = {&passthrough, &passthrough};
     if(do_verification)
     {
@@ -160,6 +159,8 @@ bool profile_batched_gemm_reduce_impl(int do_verification,
                                                              BElementOp,
                                                              CElementOp>;
+        using ReduceAccDataType = ReduceDataType;
         auto ref_batched_gemm = ReferenceBatchedGemmInstance{};
         auto ref_invoker      = ref_batched_gemm.MakeInvoker();
@@ -172,21 +173,22 @@ bool profile_batched_gemm_reduce_impl(int do_verification,
         {
             for(int m = 0; m < M; ++m)
             {
-                float d0_acc = d0_reduce_op.GetIdentityValue<float>();
-                float d1_acc = d1_reduce_op.GetIdentityValue<float>();
+                auto reduce0_acc = reduce0_op.GetIdentityValue<ReduceAccDataType>();
+                auto reduce1_acc = reduce1_op.GetIdentityValue<ReduceAccDataType>();
                 for(int n = 0; n < N; ++n)
                 {
-                    float d0_val = ck::type_convert<float>(c_g_m_n_host_result(batch, m, n));
-                    float d1_val;
-                    UnarySquareElementOp{}(d1_val, d0_val);
-                    d0_reduce_op(d0_acc, d0_val);
-                    d1_reduce_op(d1_acc, d1_val);
+                    ReduceAccDataType d0_val =
+                        ck::type_convert<ReduceAccDataType>(c_g_m_n_host_result(batch, m, n));
+                    ReduceAccDataType d1_val;
+                    square(d1_val, d0_val);
+                    reduce0_op(reduce0_acc, d0_val);
+                    reduce1_op(reduce1_acc, d1_val);
                 }
-                d0_g_m_host_result(batch, m) = ck::type_convert<DDataType>(d0_acc);
-                d1_g_m_host_result(batch, m) = ck::type_convert<DDataType>(d1_acc);
+                d0_g_m_host_result(batch, m) = ck::type_convert<ReduceDataType>(reduce0_acc);
+                d1_g_m_host_result(batch, m) = ck::type_convert<ReduceDataType>(reduce1_acc);
             }
         }
     }
@@ -194,18 +196,19 @@ bool profile_batched_gemm_reduce_impl(int do_verification,
     DeviceMem a_device_buf(sizeof(ADataType) * a_g_m_k.mDesc.GetElementSpace());
     DeviceMem b_device_buf(sizeof(BDataType) * b_g_k_n.mDesc.GetElementSpace());
     DeviceMem c_device_buf(sizeof(CDataType) * c_g_m_n_device_result.mDesc.GetElementSpace());
-    DeviceMem d0_device_buf(sizeof(DDataType) * d0_g_m_device_result.mDesc.GetElementSpace());
-    DeviceMem d1_device_buf(sizeof(DDataType) * d1_g_m_device_result.mDesc.GetElementSpace());
+    DeviceMem reduce0_device_buf(sizeof(ReduceDataType) *
+                                 d0_g_m_device_result.mDesc.GetElementSpace());
+    DeviceMem reduce1_device_buf(sizeof(ReduceDataType) *
+                                 d1_g_m_device_result.mDesc.GetElementSpace());
-    auto dxs_global = ck::make_tuple(static_cast<DDataType*>(d0_device_buf.GetDeviceBuffer()),
-                                     static_cast<DDataType*>(d1_device_buf.GetDeviceBuffer()));
+    std::array<void*, 2> p_reduces = {reduce0_device_buf.GetDeviceBuffer(),
+                                      reduce1_device_buf.GetDeviceBuffer()};
     a_device_buf.ToDevice(a_g_m_k.mData.data());
     b_device_buf.ToDevice(b_g_k_n.mData.data());
     // add device GEMM instances
-    std::vector<ck::tensor_operation::device::device_gemm_instance::DeviceGemmReduceNoOpPtr>
-        gemm_ptrs;
+    std::vector<ck::tensor_operation::device::instance::DeviceGemmReduceNoOpPtr> gemm_ptrs;
     if constexpr(is_same<ADataType, half_t>::value && is_same<BDataType, half_t>::value &&
                  is_same<CDataType, half_t>::value)
@@ -214,7 +217,7 @@ bool profile_batched_gemm_reduce_impl(int do_verification,
                      is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
                      is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
         {
-            ck::tensor_operation::device::device_gemm_instance::
+            ck::tensor_operation::device::instance::
                 add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gkn_gmn_instances(
                     gemm_ptrs);
         }
@@ -222,7 +225,7 @@ bool profile_batched_gemm_reduce_impl(int do_verification,
                           is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
                           is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
         {
-            ck::tensor_operation::device::device_gemm_instance::
+            ck::tensor_operation::device::instance::
                 add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gmk_gnk_gmn_instances(
                     gemm_ptrs);
         }
@@ -230,7 +233,7 @@ bool profile_batched_gemm_reduce_impl(int do_verification,
                           is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
                           is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
         {
-            ck::tensor_operation::device::device_gemm_instance::
+            ck::tensor_operation::device::instance::
                 add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gkn_gmn_instances(
                     gemm_ptrs);
         }
@@ -238,7 +241,7 @@ bool profile_batched_gemm_reduce_impl(int do_verification,
                           is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
                           is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
         {
-            ck::tensor_operation::device::device_gemm_instance::
+            ck::tensor_operation::device::instance::
                 add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gnk_gmn_instances(
                     gemm_ptrs);
         }
@@ -257,31 +260,32 @@ bool profile_batched_gemm_reduce_impl(int do_verification,
     // profile device GEMM instances
     for(auto& gemm_ptr : gemm_ptrs)
     {
-        auto argument_ptr =
-            gemm_ptr->MakeArgumentPointer(static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()),
-                                          static_cast<BDataType*>(b_device_buf.GetDeviceBuffer()),
-                                          static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()),
-                                          &dxs_global,
-                                          M,
-                                          N,
-                                          K,
-                                          StrideA,
-                                          StrideB,
-                                          StrideC,
-                                          a_element_op,
-                                          b_element_op,
-                                          c_element_op,
-                                          dxs_in_element_op,
-                                          dxs_out_element_op,
-                                          BatchCount);
+        auto argument_ptr = gemm_ptr->MakeArgumentPointer(a_device_buf.GetDeviceBuffer(),
+                                                          b_device_buf.GetDeviceBuffer(),
+                                                          nullptr,
+                                                          {},
+                                                          c_device_buf.GetDeviceBuffer(),
+                                                          p_reduces,
+                                                          M,
+                                                          N,
+                                                          K,
+                                                          StrideA,
+                                                          StrideB,
+                                                          StrideC,
+                                                          {},
+                                                          gemm_element_ops,
+                                                          {},
+                                                          reduce_in_element_ops,
+                                                          reduce_out_element_ops,
+                                                          BatchCount);
         auto invoker_ptr = gemm_ptr->MakeInvokerPointer();
         if(gemm_ptr->IsSupportedArgument(argument_ptr.get()))
         {
             // init DO, D1 to 0
-            d0_device_buf.SetZero();
-            d1_device_buf.SetZero();
+            reduce0_device_buf.SetZero();
+            reduce1_device_buf.SetZero();
             float ave_time =
                 invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel});
@@ -311,8 +315,8 @@ bool profile_batched_gemm_reduce_impl(int do_verification,
             if(do_verification)
             {
                 c_device_buf.FromDevice(c_g_m_n_device_result.mData.data());
-                d0_device_buf.FromDevice(d0_g_m_device_result.mData.data());
-                d1_device_buf.FromDevice(d1_g_m_device_result.mData.data());
+                reduce0_device_buf.FromDevice(d0_g_m_device_result.mData.data());
+                reduce1_device_buf.FromDevice(d1_g_m_device_result.mData.data());
                 float c_error  = check_error(c_g_m_n_host_result, c_g_m_n_device_result);
                 float d0_error = check_error(d0_g_m_host_result, d0_g_m_device_result);
...
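The reduce profiler above also switches from strongly typed tuples (ck::Tuple, ck::make_tuple) to type-erased std::array<void*, N> slots for the element-wise operators and the reduction output buffers, so a single MakeArgumentPointer signature can cover any number of reductions. A condensed sketch of how those arrays are assembled follows; the names are copied from the new code above, the type aliases are repeated here only to make the fragment compile on its own, and the exact slot order is defined by the DeviceGemmReducePtr interface, not by this note.

// Sketch only: type-erased argument slots used by the new profiler code above.
#include <array>
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

void make_gemm_reduce_slots(void* p_reduce0_device, void* p_reduce1_device)
{
    using PassThrough = ck::tensor_operation::element_wise::PassThrough;
    using UnarySquare = ck::tensor_operation::element_wise::UnarySquare;

    // Element-wise functors stay ordinary objects; only their addresses are passed.
    auto a_element_op = PassThrough{};
    auto b_element_op = PassThrough{};
    auto c_element_op = PassThrough{};
    std::array<void*, 3> gemm_element_ops = {&a_element_op, &b_element_op, &c_element_op};

    auto passthrough = PassThrough{};
    auto square      = UnarySquare{};
    std::array<void*, 2> reduce_in_element_ops  = {&passthrough, &square};      // d0 from c, d1 from c^2
    std::array<void*, 2> reduce_out_element_ops = {&passthrough, &passthrough};

    // Device buffers for the two reduction results are passed the same way.
    std::array<void*, 2> p_reduces = {p_reduce0_device, p_reduce1_device};

    (void)gemm_element_ops;
    (void)reduce_in_element_ops;
    (void)reduce_out_element_ops;
    (void)p_reduces; // handed to MakeArgumentPointer in the order shown in the diff above
}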
@@ -18,7 +18,7 @@
 namespace ck {
 namespace tensor_operation {
 namespace device {
-namespace device_conv2d_bwd_weight_instance {
+namespace instance {
 using DeviceConvBwdWeightNoOpPtr =
     DeviceConvBwdWeightPtr<ck::tensor_operation::element_wise::PassThrough,
@@ -31,7 +31,7 @@ void add_device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_f16_instances(
 void add_device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_f32_instances(
     std::vector<DeviceConvBwdWeightNoOpPtr>&);
-} // namespace device_conv2d_bwd_weight_instance
+} // namespace instance
 } // namespace device
 } // namespace tensor_operation
 } // namespace ck
@@ -165,14 +165,14 @@ bool profile_conv_bwd_weight_impl(int do_verification,
                  ck::is_same_v<ck::remove_cv_t<WeiDataType>, float> &&
                  ck::is_same_v<ck::remove_cv_t<OutDataType>, float>)
     {
-        ck::tensor_operation::device::device_conv2d_bwd_weight_instance::
+        ck::tensor_operation::device::instance::
             add_device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_f32_instances(conv_ptrs);
     }
     else if constexpr(ck::is_same_v<ck::remove_cv_t<InDataType>, ck::half_t> &&
                       ck::is_same_v<ck::remove_cv_t<WeiDataType>, ck::half_t> &&
                       ck::is_same_v<ck::remove_cv_t<OutDataType>, ck::half_t>)
     {
-        ck::tensor_operation::device::device_conv2d_bwd_weight_instance::
+        ck::tensor_operation::device::instance::
            add_device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_f16_instances(conv_ptrs);
     }
...
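The same rename applies to every per-operation instance namespace in this commit: declarations that used to live in namespaces such as device_conv2d_bwd_weight_instance are now all in ck::tensor_operation::device::instance, and call sites only change the qualifier. A small sketch of the caller-side effect, using the conv2d backward-weight declarations shown above; the wrapper function name is illustrative, and it assumes the profiler header with those declarations is already included.

// Sketch only: caller-side effect of the namespace consolidation.
#include <vector>

std::vector<ck::tensor_operation::device::instance::DeviceConvBwdWeightNoOpPtr>
make_f32_conv2d_bwd_weight_instances()
{
    std::vector<ck::tensor_operation::device::instance::DeviceConvBwdWeightNoOpPtr> conv_ptrs;

    // Before: ck::tensor_operation::device::device_conv2d_bwd_weight_instance::...
    ck::tensor_operation::device::instance::
        add_device_conv2d_bwd_weight_xdl_nhwc_kyxc_nhwk_f32_instances(conv_ptrs);

    return conv_ptrs;
}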
@@ -17,7 +17,7 @@
 namespace ck {
 namespace tensor_operation {
 namespace device {
-namespace device_conv2d_fwd_bias_activation_add_instance {
+namespace instance {
 using DeviceConvFwdBiasReluAddPtr =
     DeviceConvFwdBiasActivationAddPtr<ck::tensor_operation::element_wise::PassThrough,
@@ -27,7 +27,7 @@ using DeviceConvFwdBiasReluAddPtr =
 void add_device_conv2d_fwd_xdl_c_shuffle_bias_relu_add_nhwc_kyxc_nhwk_f16_instances(
     std::vector<DeviceConvFwdBiasReluAddPtr>&);
-} // namespace device_conv2d_fwd_bias_activation_add_instance
+} // namespace instance
 } // namespace device
 } // namespace tensor_operation
 } // namespace ck
@@ -179,7 +179,7 @@ void profile_conv_fwd_bias_relu_add_impl(int do_verification,
                  ck::is_same_v<ck::remove_cv_t<WeiDataType>, ck::half_t> &&
                  ck::is_same_v<ck::remove_cv_t<OutDataType>, ck::half_t>)
     {
-        ck::tensor_operation::device::device_conv2d_fwd_bias_activation_add_instance::
+        ck::tensor_operation::device::instance::
             add_device_conv2d_fwd_xdl_c_shuffle_bias_relu_add_nhwc_kyxc_nhwk_f16_instances(op_ptrs);
     }
...
@@ -17,7 +17,7 @@
 namespace ck {
 namespace tensor_operation {
 namespace device {
-namespace device_conv2d_fwd_bias_activation_instance {
+namespace instance {
 using DeviceConvFwdBiasReluPtr =
     DeviceConvFwdBiasActivationPtr<ck::tensor_operation::element_wise::PassThrough,
@@ -27,7 +27,7 @@ using DeviceConvFwdBiasReluPtr =
 void add_device_conv2d_fwd_xdl_c_shuffle_bias_relu_nhwc_kyxc_nhwk_f16_instances(
     std::vector<DeviceConvFwdBiasReluPtr>&);
-} // namespace device_conv2d_fwd_bias_activation_instance
+} // namespace instance
 } // namespace device
 } // namespace tensor_operation
 } // namespace ck
@@ -169,7 +169,7 @@ void profile_conv_fwd_bias_relu_impl(int do_verification,
                  ck::is_same_v<ck::remove_cv_t<WeiDataType>, ck::half_t> &&
                  ck::is_same_v<ck::remove_cv_t<OutDataType>, ck::half_t>)
     {
-        ck::tensor_operation::device::device_conv2d_fwd_bias_activation_instance::
+        ck::tensor_operation::device::instance::
             add_device_conv2d_fwd_xdl_c_shuffle_bias_relu_nhwc_kyxc_nhwk_f16_instances(op_ptrs);
     }
...
@@ -22,7 +22,7 @@ using INT8 = int8_t;
 namespace ck {
 namespace tensor_operation {
 namespace device {
-namespace device_conv2d_bwd_data_instance {
+namespace instance {
 using DeviceConvBwdDataNoOpPtr =
     DeviceConvBwdDataPtr<ck::tensor_operation::element_wise::PassThrough,
@@ -54,15 +54,14 @@ void add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_bf16_instances(
     std::vector<DeviceConvBwdDataNoOpPtr>&);
 void add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_int8_instances(
     std::vector<DeviceConvBwdDataNoOpPtr>&);
-} // namespace device_conv2d_bwd_data_instance
+} // namespace instance
 } // namespace device
 } // namespace tensor_operation
 } // namespace ck
 namespace ck {
 namespace profiler {
-using DeviceConvBwdDataNoOpPtr =
-    ck::tensor_operation::device::device_conv2d_bwd_data_instance::DeviceConvBwdDataNoOpPtr;
+using DeviceConvBwdDataNoOpPtr = ck::tensor_operation::device::instance::DeviceConvBwdDataNoOpPtr;
 template <typename InLayout>
 HostTensorDescriptor get_input_host_tensor_descriptor(const std::vector<std::size_t>& dims,
@@ -144,15 +143,15 @@ void get_device_conv_bwd_data_op_ptr(
         switch(num_dim_spatial)
         {
         case 1:
-            ck::tensor_operation::device::device_conv2d_bwd_data_instance::
+            ck::tensor_operation::device::instance::
                 add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f32_instances(conv_ptrs);
             break;
         case 2:
-            ck::tensor_operation::device::device_conv2d_bwd_data_instance::
+            ck::tensor_operation::device::instance::
                 add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f32_instances(conv_ptrs);
             break;
         case 3:
-            ck::tensor_operation::device::device_conv2d_bwd_data_instance::
+            ck::tensor_operation::device::instance::
                 add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f32_instances(conv_ptrs);
             break;
         default: break;
...@@ -165,15 +164,15 @@ void get_device_conv_bwd_data_op_ptr( ...@@ -165,15 +164,15 @@ void get_device_conv_bwd_data_op_ptr(
switch(num_dim_spatial) switch(num_dim_spatial)
{ {
case 1: case 1:
ck::tensor_operation::device::device_conv2d_bwd_data_instance:: ck::tensor_operation::device::instance::
add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f16_instances(conv_ptrs); add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f16_instances(conv_ptrs);
break; break;
case 2: case 2:
ck::tensor_operation::device::device_conv2d_bwd_data_instance:: ck::tensor_operation::device::instance::
add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f16_instances(conv_ptrs); add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f16_instances(conv_ptrs);
break; break;
case 3: case 3:
ck::tensor_operation::device::device_conv2d_bwd_data_instance:: ck::tensor_operation::device::instance::
add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f16_instances(conv_ptrs); add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f16_instances(conv_ptrs);
break; break;
default: break; default: break;
...@@ -186,15 +185,15 @@ void get_device_conv_bwd_data_op_ptr( ...@@ -186,15 +185,15 @@ void get_device_conv_bwd_data_op_ptr(
switch(num_dim_spatial) switch(num_dim_spatial)
{ {
case 1: case 1:
ck::tensor_operation::device::device_conv2d_bwd_data_instance:: ck::tensor_operation::device::instance::
add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_bf16_instances(conv_ptrs); add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_bf16_instances(conv_ptrs);
break; break;
case 2: case 2:
ck::tensor_operation::device::device_conv2d_bwd_data_instance:: ck::tensor_operation::device::instance::
add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instances(conv_ptrs); add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instances(conv_ptrs);
break; break;
case 3: case 3:
ck::tensor_operation::device::device_conv2d_bwd_data_instance:: ck::tensor_operation::device::instance::
add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_bf16_instances(conv_ptrs); add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_bf16_instances(conv_ptrs);
break; break;
default: break; default: break;
...@@ -207,15 +206,15 @@ void get_device_conv_bwd_data_op_ptr( ...@@ -207,15 +206,15 @@ void get_device_conv_bwd_data_op_ptr(
switch(num_dim_spatial) switch(num_dim_spatial)
{ {
case 1: case 1:
ck::tensor_operation::device::device_conv2d_bwd_data_instance:: ck::tensor_operation::device::instance::
add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_int8_instances(conv_ptrs); add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_int8_instances(conv_ptrs);
break; break;
case 2: case 2:
ck::tensor_operation::device::device_conv2d_bwd_data_instance:: ck::tensor_operation::device::instance::
add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instances(conv_ptrs); add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instances(conv_ptrs);
break; break;
case 3: case 3:
ck::tensor_operation::device::device_conv2d_bwd_data_instance:: ck::tensor_operation::device::instance::
add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_int8_instances(conv_ptrs); add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_int8_instances(conv_ptrs);
break; break;
default: break; default: break;
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
namespace ck {
namespace profiler {
int profile_convnd_fwd(int argc, char* argv[]);
} // namespace profiler
} // namespace ck
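Below is a minimal, hypothetical driver sketch (not part of this commit) showing how an entry point such as profile_convnd_fwd declared above is typically invoked: a small main that dispatches on an operation name and forwards the remaining command-line arguments. The stub body, the "convnd_fwd" keyword, and the return convention are illustrative assumptions only.

#include <cstring>
#include <iostream>

namespace ck {
namespace profiler {
// Illustrative stub so the sketch is self-contained; the real implementation
// lives in the profiler library and parses its own arguments.
int profile_convnd_fwd(int /*argc*/, char* /*argv*/[])
{
    std::cout << "profile_convnd_fwd called" << std::endl;
    return 0;
}
} // namespace profiler
} // namespace ck

int main(int argc, char* argv[])
{
    // Dispatch on the first argument and forward the rest to the profiler entry point.
    if(argc > 1 && std::strcmp(argv[1], "convnd_fwd") == 0)
    {
        return ck::profiler::profile_convnd_fwd(argc - 1, argv + 1);
    }
    return 0;
}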
...@@ -9,38 +9,15 @@ ...@@ -9,38 +9,15 @@
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp" #include "ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/gpu/gemm_add_add_fastgelu.hpp"
#include "ck/library/utility/check_err.hpp" #include "ck/library/utility/check_err.hpp"
#include "ck/library/host_tensor/device_memory.hpp" #include "ck/library/host_tensor/device_memory.hpp"
#include "ck/library/host_tensor/host_tensor.hpp" #include "ck/library/host_tensor/host_tensor.hpp"
#include "ck/library/host_tensor/host_tensor_generator.hpp" #include "ck/library/host_tensor/host_tensor_generator.hpp"
#include "ck/library/host_tensor/host_conv.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" #include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace device_gemm_instance {
using DeviceGemmAddAddFastGeluPtr = ck::tensor_operation::device::DeviceGemmMultipleDPtr<
2,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::AddAddFastGelu>;
void add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances(
std::vector<DeviceGemmAddAddFastGeluPtr>&);
void add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances(
std::vector<DeviceGemmAddAddFastGeluPtr>&);
void add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances(
std::vector<DeviceGemmAddAddFastGeluPtr>&);
void add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances(
std::vector<DeviceGemmAddAddFastGeluPtr>&);
} // namespace device_gemm_instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
namespace ck { namespace ck {
namespace profiler { namespace profiler {
...@@ -52,21 +29,19 @@ template <typename ADataType, ...@@ -52,21 +29,19 @@ template <typename ADataType,
typename EDataType, typename EDataType,
typename ALayout, typename ALayout,
typename BLayout, typename BLayout,
typename D0Layout, typename DELayout> // assume Ds and E have same layout
typename D1Layout, bool profile_gemm_add_add_fastgelu_impl(int do_verification,
typename ELayout> int init_method,
int profile_gemm_add_add_fastgelu_impl(int do_verification, bool /*do_log*/,
int init_method, bool time_kernel,
bool /*do_log*/, int M,
bool time_kernel, int N,
int M, int K,
int N, int StrideA,
int K, int StrideB,
int StrideA, int StrideD0,
int StrideB, int StrideD1,
int StrideD0, int StrideE)
int StrideD1,
int StrideE)
{ {
auto f_host_tensor_descriptor = auto f_host_tensor_descriptor =
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) { [](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
...@@ -84,10 +59,10 @@ int profile_gemm_add_add_fastgelu_impl(int do_verification, ...@@ -84,10 +59,10 @@ int profile_gemm_add_add_fastgelu_impl(int do_verification,
Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
Tensor<D0DataType> d0_m_n(f_host_tensor_descriptor(M, N, StrideD0, D0Layout{})); Tensor<D0DataType> d0_m_n(f_host_tensor_descriptor(M, N, StrideD0, DELayout{}));
Tensor<D1DataType> d1_m_n(f_host_tensor_descriptor(M, N, StrideD1, D1Layout{})); Tensor<D1DataType> d1_m_n(f_host_tensor_descriptor(M, N, StrideD1, DELayout{}));
Tensor<EDataType> e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); Tensor<EDataType> e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, DELayout{}));
Tensor<EDataType> e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); Tensor<EDataType> e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, DELayout{}));
std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
...@@ -122,48 +97,23 @@ int profile_gemm_add_add_fastgelu_impl(int do_verification, ...@@ -122,48 +97,23 @@ int profile_gemm_add_add_fastgelu_impl(int do_verification,
const auto b_element_op = BElementOp{}; const auto b_element_op = BElementOp{};
const auto cde_element_op = CDEElementOp{}; const auto cde_element_op = CDEElementOp{};
// add device GEMM instances using DeviceOp = ck::tensor_operation::device::DeviceGemmMultipleD<
std::vector<ck::tensor_operation::device::device_gemm_instance::DeviceGemmAddAddFastGeluPtr> ALayout,
device_op_ptrs; BLayout,
DELayout,
ADataType,
BDataType,
ck::Tuple<D0DataType, D1DataType>,
EDataType,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::AddAddFastGelu>;
if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, half_t> && // get device op instances
is_same_v<EDataType, half_t>) const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
{ DeviceOp>::GetInstances();
if constexpr(is_same_v<ALayout, tensor_layout::gemm::RowMajor> &&
is_same_v<BLayout, tensor_layout::gemm::RowMajor> &&
is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
{
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances(
device_op_ptrs);
}
else if constexpr(is_same_v<ALayout, tensor_layout::gemm::RowMajor> &&
is_same_v<BLayout, tensor_layout::gemm::ColumnMajor> &&
is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
{
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances(
device_op_ptrs);
}
else if constexpr(is_same_v<ALayout, tensor_layout::gemm::ColumnMajor> &&
is_same_v<BLayout, tensor_layout::gemm::RowMajor> &&
is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
{
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances(
device_op_ptrs);
}
else if constexpr(is_same_v<ALayout, tensor_layout::gemm::ColumnMajor> &&
is_same_v<BLayout, tensor_layout::gemm::ColumnMajor> &&
is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
{
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances(
device_op_ptrs);
}
}
std::cout << "found " << device_op_ptrs.size() << " instances" << std::endl; std::cout << "found " << op_ptrs.size() << " instances" << std::endl;
// run reference // run reference
if(do_verification) if(do_verification)
...@@ -207,7 +157,7 @@ int profile_gemm_add_add_fastgelu_impl(int do_verification, ...@@ -207,7 +157,7 @@ int profile_gemm_add_add_fastgelu_impl(int do_verification,
d0_m_n_device_buf.ToDevice(d0_m_n.mData.data()); d0_m_n_device_buf.ToDevice(d0_m_n.mData.data());
d1_m_n_device_buf.ToDevice(d1_m_n.mData.data()); d1_m_n_device_buf.ToDevice(d1_m_n.mData.data());
std::string best_device_op_name; std::string best_op_name;
float best_ave_time = 0; float best_ave_time = 0;
float best_tflops = 0; float best_tflops = 0;
float best_gb_per_sec = 0; float best_gb_per_sec = 0;
...@@ -215,14 +165,14 @@ int profile_gemm_add_add_fastgelu_impl(int do_verification, ...@@ -215,14 +165,14 @@ int profile_gemm_add_add_fastgelu_impl(int do_verification,
bool pass = true; bool pass = true;
// profile device operation instances // profile device operation instances
for(auto& device_op_ptr : device_op_ptrs) for(auto& op_ptr : op_ptrs)
{ {
auto argument_ptr = device_op_ptr->MakeArgumentPointer( auto argument_ptr = op_ptr->MakeArgumentPointer(
a_device_buf.GetDeviceBuffer(), a_device_buf.GetDeviceBuffer(),
b_device_buf.GetDeviceBuffer(), b_device_buf.GetDeviceBuffer(),
std::array<const void*, 2>{d0_m_n_device_buf.GetDeviceBuffer(), std::array<const void*, 2>{d0_m_n_device_buf.GetDeviceBuffer(),
d1_m_n_device_buf.GetDeviceBuffer()}, d1_m_n_device_buf.GetDeviceBuffer()},
static_cast<EDataType*>(e_device_buf.GetDeviceBuffer()), e_device_buf.GetDeviceBuffer(),
M, M,
N, N,
K, K,
...@@ -234,11 +184,11 @@ int profile_gemm_add_add_fastgelu_impl(int do_verification, ...@@ -234,11 +184,11 @@ int profile_gemm_add_add_fastgelu_impl(int do_verification,
b_element_op, b_element_op,
cde_element_op); cde_element_op);
auto invoker_ptr = device_op_ptr->MakeInvokerPointer(); auto invoker_ptr = op_ptr->MakeInvokerPointer();
std::string device_op_name = device_op_ptr->GetTypeString(); std::string op_name = op_ptr->GetTypeString();
if(device_op_ptr->IsSupportedArgument(argument_ptr.get())) if(op_ptr->IsSupportedArgument(argument_ptr.get()))
{ {
// re-init E to zero before profiling a kernel // re-init E to zero before profiling a kernel
e_device_buf.SetZero(); e_device_buf.SetZero();
...@@ -256,14 +206,14 @@ int profile_gemm_add_add_fastgelu_impl(int do_verification, ...@@ -256,14 +206,14 @@ int profile_gemm_add_add_fastgelu_impl(int do_verification,
float gb_per_sec = num_btype / 1.E6 / ave_time; float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, " std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, "
<< gb_per_sec << " GB/s, " << device_op_name << std::endl; << gb_per_sec << " GB/s, " << op_name << std::endl;
if(tflops > best_tflops) if(tflops > best_tflops)
{ {
best_device_op_name = device_op_name; best_op_name = op_name;
best_tflops = tflops; best_tflops = tflops;
best_ave_time = ave_time; best_ave_time = ave_time;
best_gb_per_sec = gb_per_sec; best_gb_per_sec = gb_per_sec;
} }
if(do_verification) if(do_verification)
...@@ -276,14 +226,14 @@ int profile_gemm_add_add_fastgelu_impl(int do_verification, ...@@ -276,14 +226,14 @@ int profile_gemm_add_add_fastgelu_impl(int do_verification,
} }
else else
{ {
std::cout << device_op_name << " does not support this problem" << std::endl; std::cout << op_name << " does not support this problem" << std::endl;
} }
} }
std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, " std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, "
<< best_gb_per_sec << " GB/s, " << best_device_op_name << std::endl; << best_gb_per_sec << " GB/s, " << best_op_name << std::endl;
return pass ? 0 : 1; return pass;
} }
} // namespace profiler } // namespace profiler
......
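The right-hand column of the hunk above replaces the hand-registered add_device_gemm_add_add_fastgelu_* declarations with the generic instance factory. A condensed sketch of that pattern follows; it assumes the CK headers shown in the diff, and the F16 data types and row-major layouts are example choices rather than the only supported configuration.

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/gpu/gemm_add_add_fastgelu.hpp"

using F16            = ck::half_t;
using Row            = ck::tensor_layout::gemm::RowMajor;
using PassThrough    = ck::tensor_operation::element_wise::PassThrough;
using AddAddFastGelu = ck::tensor_operation::element_wise::AddAddFastGelu;

// One DeviceOp type now describes the whole fused problem; the factory
// enumerates every kernel instance registered for it.
using DeviceOp = ck::tensor_operation::device::DeviceGemmMultipleD<Row,
                                                                   Row,
                                                                   Row,
                                                                   F16,
                                                                   F16,
                                                                   ck::Tuple<F16, F16>,
                                                                   F16,
                                                                   PassThrough,
                                                                   PassThrough,
                                                                   AddAddFastGelu>;

int main()
{
    const auto op_ptrs = ck::tensor_operation::device::instance::
        DeviceOperationInstanceFactory<DeviceOp>::GetInstances();

    // Each op_ptr can build an argument, be checked with IsSupportedArgument()
    // and run through its invoker, exactly as the profiler loop above does.
    return op_ptrs.empty() ? 1 : 0;
}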
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_bias.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/host_tensor/device_memory.hpp"
#include "ck/library/host_tensor/host_tensor.hpp"
#include "ck/library/host_tensor/host_tensor_generator.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm_bias_2d.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace device_gemm_instance {
using DeviceGemmAlphaBetaPtr = ck::tensor_operation::device::DeviceGemmBiasPtr<
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::AlphaBetaAdd>;
void add_device_gemm_xdl_c_shuffle_bias_2d_f32_f32_f32_km_kn_mn_instances(
std::vector<DeviceGemmAlphaBetaPtr>&);
void add_device_gemm_xdl_c_shuffle_bias_2d_f32_f32_f32_km_nk_mn_instances(
std::vector<DeviceGemmAlphaBetaPtr>&);
void add_device_gemm_xdl_c_shuffle_bias_2d_f32_f32_f32_mk_kn_mn_instances(
std::vector<DeviceGemmAlphaBetaPtr>&);
void add_device_gemm_xdl_c_shuffle_bias_2d_f32_f32_f32_mk_nk_mn_instances(
std::vector<DeviceGemmAlphaBetaPtr>&);
void add_device_gemm_xdl_c_shuffle_bias_2d_f16_f16_f16_km_kn_mn_instances(
std::vector<DeviceGemmAlphaBetaPtr>&);
void add_device_gemm_xdl_c_shuffle_bias_2d_f16_f16_f16_km_nk_mn_instances(
std::vector<DeviceGemmAlphaBetaPtr>&);
void add_device_gemm_xdl_c_shuffle_bias_2d_f16_f16_f16_mk_kn_mn_instances(
std::vector<DeviceGemmAlphaBetaPtr>&);
void add_device_gemm_xdl_c_shuffle_bias_2d_f16_f16_f16_mk_nk_mn_instances(
std::vector<DeviceGemmAlphaBetaPtr>&);
} // namespace device_gemm_instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
namespace ck {
namespace profiler {
template <typename ADataType,
typename BDataType,
typename C0DataType,
typename CDataType,
typename AccDataType,
typename ALayout,
typename BLayout,
typename CLayout>
void profile_gemm_bias_2d_impl(int do_verification,
int init_method,
bool do_log,
bool time_kernel,
int M,
int N,
int K,
int StrideA,
int StrideB,
int StrideC,
float alpha,
float beta)
{
auto f_host_tensor_descriptor =
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
if(is_same<decltype(layout), tensor_layout::gemm::RowMajor>::value)
{
return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
std::vector<std::size_t>({stride, 1}));
}
else
{
return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
std::vector<std::size_t>({1, stride}));
}
};
Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
Tensor<C0DataType> c0_m_n(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
Tensor<CDataType> c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
Tensor<CDataType> c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
std::cout << "c0_m_n: " << c0_m_n.mDesc << std::endl;
std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl;
std::size_t num_thread = 1;
switch(init_method)
{
case 0: break;
case 1:
a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5}, num_thread);
b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5}, num_thread);
c0_m_n.GenerateTensorValue(GeneratorTensor_2<C0DataType>{-5, 5}, num_thread);
break;
default:
a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0}, num_thread);
b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5}, num_thread);
c0_m_n.GenerateTensorValue(GeneratorTensor_3<C0DataType>{-0.5, 0.5}, num_thread);
}
// initialize c_m_n_device_result to zero; it is copied into c_device_buf below
c_m_n_device_result.GenerateTensorValue(GeneratorTensor_0<CDataType>{}, num_thread);
using AElementOp = ck::tensor_operation::element_wise::PassThrough;
using BElementOp = ck::tensor_operation::element_wise::PassThrough;
using CElementOp = ck::tensor_operation::element_wise::AlphaBetaAdd;
const auto a_element_op = AElementOp{};
const auto b_element_op = BElementOp{};
const auto c_element_op = CElementOp{alpha, beta};
if(do_verification)
{
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemmBias2D<ADataType,
BDataType,
C0DataType,
CDataType,
AccDataType,
AElementOp,
BElementOp,
CElementOp>;
auto ref_gemm = ReferenceGemmInstance{};
auto ref_invoker = ref_gemm.MakeInvoker();
auto ref_argument = ref_gemm.MakeArgument(
a_m_k, b_k_n, c0_m_n, c_m_n_host_result, a_element_op, b_element_op, c_element_op);
ref_invoker.Run(ref_argument);
}
DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace());
DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace());
DeviceMem c0_device_buf(sizeof(C0DataType) * c0_m_n.mDesc.GetElementSpace());
DeviceMem c_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpace());
a_device_buf.ToDevice(a_m_k.mData.data());
b_device_buf.ToDevice(b_k_n.mData.data());
c0_device_buf.ToDevice(c0_m_n.mData.data());
c_device_buf.ToDevice(c_m_n_device_result.mData.data());
// add device GEMM instances
std::vector<ck::tensor_operation::device::device_gemm_instance::DeviceGemmAlphaBetaPtr>
gemm_ptrs;
if constexpr(is_same<ADataType, half_t>::value && is_same<BDataType, half_t>::value &&
is_same<CDataType, half_t>::value)
{
if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_c_shuffle_bias_2d_f16_f16_f16_mk_kn_mn_instances(gemm_ptrs);
}
else if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_c_shuffle_bias_2d_f16_f16_f16_mk_nk_mn_instances(gemm_ptrs);
}
else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_c_shuffle_bias_2d_f16_f16_f16_km_kn_mn_instances(gemm_ptrs);
}
else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_c_shuffle_bias_2d_f16_f16_f16_km_nk_mn_instances(gemm_ptrs);
}
}
else if constexpr(is_same<ADataType, float>::value && is_same<BDataType, float>::value &&
is_same<CDataType, float>::value)
{
if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_c_shuffle_bias_2d_f32_f32_f32_mk_kn_mn_instances(gemm_ptrs);
}
else if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_c_shuffle_bias_2d_f32_f32_f32_mk_nk_mn_instances(gemm_ptrs);
}
else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_c_shuffle_bias_2d_f32_f32_f32_km_kn_mn_instances(gemm_ptrs);
}
else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{
ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_c_shuffle_bias_2d_f32_f32_f32_km_nk_mn_instances(gemm_ptrs);
}
}
if(gemm_ptrs.size() <= 0)
{
throw std::runtime_error("wrong! no device GEMM instance found");
}
std::string best_gemm_name;
float best_ave_time = 0;
float best_tflops = 0;
float best_gb_per_sec = 0;
// profile device GEMM instances
for(auto& gemm_ptr : gemm_ptrs)
{
auto argument_ptr =
gemm_ptr->MakeArgumentPointer(static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()),
static_cast<BDataType*>(b_device_buf.GetDeviceBuffer()),
static_cast<C0DataType*>(c0_device_buf.GetDeviceBuffer()),
static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()),
M,
N,
K,
StrideA,
StrideB,
StrideC,
a_element_op,
b_element_op,
c_element_op);
auto invoker_ptr = gemm_ptr->MakeInvokerPointer();
if(gemm_ptr->IsSupportedArgument(argument_ptr.get()))
{
std::string gemm_name = gemm_ptr->GetTypeString();
float ave_time =
invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel});
std::size_t flop = std::size_t(2) * M * N * K;
std::size_t num_btype =
sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec
<< " GB/s, " << gemm_name << std::endl;
if(tflops > best_tflops)
{
best_gemm_name = gemm_name;
best_tflops = tflops;
best_ave_time = ave_time;
best_gb_per_sec = gb_per_sec;
}
if(do_verification)
{
c_device_buf.FromDevice(c_m_n_device_result.mData.data());
ck::utils::check_err(c_m_n_device_result.mData, c_m_n_host_result.mData);
if(do_log)
{
LogRangeAsType<float>(std::cout << "a : ", a_m_k.mData, ",") << std::endl;
LogRangeAsType<float>(std::cout << "b: ", b_k_n.mData, ",") << std::endl;
LogRangeAsType<float>(std::cout << "c0 : ", c0_m_n.mData, ",") << std::endl;
LogRangeAsType<float>(std::cout << "c_host : ", c_m_n_host_result.mData, ",")
<< std::endl;
LogRangeAsType<float>(std::cout << "c_device: ", c_m_n_device_result.mData, ",")
<< std::endl;
}
}
}
else
{
std::cout << "does not support this GEMM problem" << std::endl;
}
}
std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, "
<< best_gb_per_sec << " GB/s, " << best_gemm_name << std::endl;
}
} // namespace profiler
} // namespace ck
...@@ -19,38 +19,33 @@ ...@@ -19,38 +19,33 @@
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace device { namespace device {
namespace device_gemm_instance { namespace instance {
using F32 = float; using F32 = float;
using F16 = ck::half_t; using F16 = ck::half_t;
using DPtrsGlobal = ck::Tuple<F32*, F32*>; using ReducePtrsGlobal = ck::Tuple<F32*, F32*>;
using Div = ck::tensor_operation::element_wise::UnaryDivide; using Div = ck::tensor_operation::element_wise::UnaryDivide;
using Identity = ck::tensor_operation::element_wise::PassThrough; using Identity = ck::tensor_operation::element_wise::PassThrough;
using Square = ck::tensor_operation::element_wise::UnarySquare; using Square = ck::tensor_operation::element_wise::UnarySquare;
using DInElementOps = ck::Tuple<Identity, Square>; using ReduceInElementOps = ck::Tuple<Identity, Square>;
using DOutElementOps = ck::Tuple<Div, Div>; using ReduceOutElementOps = ck::Tuple<Div, Div>;
using DeviceGemmBiasAddReduceNoOpPtr = ck::tensor_operation::device::DeviceGemmBiasAddReducePtr< using DeviceGemmBiasAddReduceNoOpPtr =
ck::tensor_operation::element_wise::PassThrough, ck::tensor_operation::device::DeviceGemmReducePtr<1, ReducePtrsGlobal::Size()>;
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough, void add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_mk_kn_mn_instances(
ck::tensor_operation::element_wise::PassThrough,
DInElementOps,
DOutElementOps>;
void add_device_gemm_bias_add_reduce_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_mk_kn_mn_instances(
std::vector<DeviceGemmBiasAddReduceNoOpPtr>&); std::vector<DeviceGemmBiasAddReduceNoOpPtr>&);
void add_device_gemm_bias_add_reduce_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_mk_nk_mn_instances( void add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_mk_nk_mn_instances(
std::vector<DeviceGemmBiasAddReduceNoOpPtr>&); std::vector<DeviceGemmBiasAddReduceNoOpPtr>&);
void add_device_gemm_bias_add_reduce_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_kn_mn_instances( void add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_kn_mn_instances(
std::vector<DeviceGemmBiasAddReduceNoOpPtr>&); std::vector<DeviceGemmBiasAddReduceNoOpPtr>&);
void add_device_gemm_bias_add_reduce_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_nk_mn_instances( void add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_nk_mn_instances(
std::vector<DeviceGemmBiasAddReduceNoOpPtr>&); std::vector<DeviceGemmBiasAddReduceNoOpPtr>&);
} // namespace device_gemm_instance } // namespace instance
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
} // namespace ck } // namespace ck
...@@ -61,9 +56,9 @@ namespace profiler { ...@@ -61,9 +56,9 @@ namespace profiler {
template <typename ADataType, template <typename ADataType,
typename BDataType, typename BDataType,
typename CDataType, typename CDataType,
typename C0DataType, typename BiasDataType,
typename C1DataType, typename D0DataType,
typename DDataType, typename ReduceDataType,
typename ALayout, typename ALayout,
typename BLayout, typename BLayout,
typename CLayout> typename CLayout>
...@@ -77,7 +72,7 @@ void profile_gemm_bias_add_reduce_impl(int do_verification, ...@@ -77,7 +72,7 @@ void profile_gemm_bias_add_reduce_impl(int do_verification,
int StrideA, int StrideA,
int StrideB, int StrideB,
int StrideC, int StrideC,
int StrideC1) int StrideD0)
{ {
auto f_host_tensor_descriptor1d = [](std::size_t len, std::size_t stride) { auto f_host_tensor_descriptor1d = [](std::size_t len, std::size_t stride) {
return HostTensorDescriptor(std::vector<std::size_t>({len}), return HostTensorDescriptor(std::vector<std::size_t>({len}),
...@@ -102,24 +97,24 @@ void profile_gemm_bias_add_reduce_impl(int do_verification, ...@@ -102,24 +97,24 @@ void profile_gemm_bias_add_reduce_impl(int do_verification,
Tensor<BDataType> b_k_n(f_host_tensor_descriptor2d(K, N, StrideB, BLayout{})); Tensor<BDataType> b_k_n(f_host_tensor_descriptor2d(K, N, StrideB, BLayout{}));
Tensor<CDataType> c_m_n_host_result(f_host_tensor_descriptor2d(M, N, StrideC, CLayout{})); Tensor<CDataType> c_m_n_host_result(f_host_tensor_descriptor2d(M, N, StrideC, CLayout{}));
Tensor<C0DataType> bias_n(f_host_tensor_descriptor1d(N, 1)); Tensor<BiasDataType> bias_n(f_host_tensor_descriptor1d(N, 1));
Tensor<C1DataType> c1_m_n(f_host_tensor_descriptor2d(M, N, StrideC, CLayout{})); Tensor<D0DataType> d0_m_n(f_host_tensor_descriptor2d(M, N, StrideC, CLayout{}));
Tensor<DDataType> d0_m_host_result( Tensor<ReduceDataType> reduce0_m_host_result(
HostTensorDescriptor(std::vector<std::size_t>({static_cast<std::size_t>(M)}))); HostTensorDescriptor(std::vector<std::size_t>({static_cast<std::size_t>(M)})));
Tensor<DDataType> d1_m_host_result( Tensor<ReduceDataType> reduce1_m_host_result(
HostTensorDescriptor(std::vector<std::size_t>({static_cast<std::size_t>(M)}))); HostTensorDescriptor(std::vector<std::size_t>({static_cast<std::size_t>(M)})));
Tensor<CDataType> c_m_n_device_result(f_host_tensor_descriptor2d(M, N, StrideC, CLayout{})); Tensor<CDataType> c_m_n_device_result(f_host_tensor_descriptor2d(M, N, StrideC, CLayout{}));
Tensor<DDataType> d0_m_device_result( Tensor<ReduceDataType> reduce0_m_device_result(
HostTensorDescriptor(std::vector<std::size_t>({static_cast<std::size_t>(M)}))); HostTensorDescriptor(std::vector<std::size_t>({static_cast<std::size_t>(M)})));
Tensor<DDataType> d1_m_device_result( Tensor<ReduceDataType> reduce1_m_device_result(
HostTensorDescriptor(std::vector<std::size_t>({static_cast<std::size_t>(M)}))); HostTensorDescriptor(std::vector<std::size_t>({static_cast<std::size_t>(M)})));
std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl; std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl;
std::cout << "d0_m: " << d0_m_host_result.mDesc << std::endl; std::cout << "reduce0_m: " << reduce0_m_host_result.mDesc << std::endl;
std::cout << "d1_m: " << d1_m_host_result.mDesc << std::endl; std::cout << "reduce1_m: " << reduce1_m_host_result.mDesc << std::endl;
std::size_t num_thread = 1; std::size_t num_thread = 1;
switch(init_method) switch(init_method)
...@@ -130,50 +125,53 @@ void profile_gemm_bias_add_reduce_impl(int do_verification, ...@@ -130,50 +125,53 @@ void profile_gemm_bias_add_reduce_impl(int do_verification,
a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5}, num_thread); a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5}, num_thread);
b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5}, num_thread); b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5}, num_thread);
bias_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5}, num_thread); bias_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5}, num_thread);
c1_m_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5}, num_thread); d0_m_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5}, num_thread);
break; break;
default: default:
std::srand(0); std::srand(0);
a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0}, num_thread); a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0}, num_thread);
b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5}, num_thread); b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5}, num_thread);
bias_n.GenerateTensorValue(GeneratorTensor_3<ADataType>{-0.5, 0.5}, num_thread); bias_n.GenerateTensorValue(GeneratorTensor_3<ADataType>{-0.5, 0.5}, num_thread);
c1_m_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5}, num_thread); d0_m_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5}, num_thread);
} }
using PassThrough = ck::tensor_operation::element_wise::PassThrough; using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using AElementOp = PassThrough; using AElementOp = PassThrough;
using BElementOp = PassThrough; using BElementOp = PassThrough;
using CElementOp = PassThrough; using CElementOp = PassThrough;
using C1ElementOp = PassThrough; using D0ElementOp = PassThrough;
using D0ReduceOp = ck::reduce::Add; using ReduceOp0 = ck::reduce::Add;
using D1ReduceOp = ck::reduce::Add; using ReduceOp1 = ck::reduce::Add;
using UnaryDivElementOp = ck::tensor_operation::element_wise::UnaryDivide; using UnaryDivElementOp = ck::tensor_operation::element_wise::UnaryDivide;
using UnaryIdenticElementOp = ck::tensor_operation::element_wise::PassThrough; using UnaryIdenticElementOp = ck::tensor_operation::element_wise::PassThrough;
using UnarySquareElementOp = ck::tensor_operation::element_wise::UnarySquare; using UnarySquareElementOp = ck::tensor_operation::element_wise::UnarySquare;
using DxsInElementOps = ck::Tuple<UnaryIdenticElementOp, UnarySquareElementOp>;
using DxsOutElementOps = ck::Tuple<UnaryDivElementOp, UnaryDivElementOp>;
const auto a_element_op = AElementOp{}; auto a_element_op = AElementOp{};
const auto b_element_op = BElementOp{}; auto b_element_op = BElementOp{};
const auto c_element_op = CElementOp{}; auto c_element_op = CElementOp{};
const auto c1_element_op = C1ElementOp{}; std::array<void*, 3> gemm_element_ops = {&a_element_op, &b_element_op, &c_element_op};
const auto d0_reduce_op = D0ReduceOp{};
const auto d1_reduce_op = D1ReduceOp{}; auto d0_element_op = D0ElementOp{};
const auto reduce0_op = ReduceOp0{};
const auto reduce1_op = ReduceOp1{};
auto dxs_in_element_op = DxsInElementOps{}; auto passthrough = UnaryIdenticElementOp{};
auto dxs_out_element_op = DxsOutElementOps{N, N}; auto square = UnarySquareElementOp{};
auto div = UnaryDivElementOp{N};
std::array<void*, 2> reduce_in_element_ops = {&passthrough, &square};
std::array<void*, 2> reduce_out_element_ops = {&div, &div};
if(do_verification) if(do_verification)
{ {
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType, using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
BDataType, BDataType,
CDataType, CDataType,
DDataType, ReduceDataType,
AElementOp, AElementOp,
BElementOp, BElementOp,
CElementOp>; CElementOp>;
using ReduceAccDataType = DDataType; using ReduceAccDataType = ReduceDataType;
auto ref_gemm = ReferenceGemmInstance{}; auto ref_gemm = ReferenceGemmInstance{};
auto ref_invoker = ref_gemm.MakeInvoker(); auto ref_invoker = ref_gemm.MakeInvoker();
...@@ -189,57 +187,56 @@ void profile_gemm_bias_add_reduce_impl(int do_verification, ...@@ -189,57 +187,56 @@ void profile_gemm_bias_add_reduce_impl(int do_verification,
ReduceAccDataType acc = static_cast<ReduceAccDataType>(c_m_n_host_result(m, n)) + ReduceAccDataType acc = static_cast<ReduceAccDataType>(c_m_n_host_result(m, n)) +
static_cast<ReduceAccDataType>(bias_n(n)); static_cast<ReduceAccDataType>(bias_n(n));
ReduceAccDataType c1 = static_cast<ReduceAccDataType>(c1_m_n(m, n)); ReduceAccDataType d0 = static_cast<ReduceAccDataType>(d0_m_n(m, n));
c_element_op(acc, acc); c_element_op(acc, acc);
c1_element_op(c1, c1); d0_element_op(d0, d0);
acc += c1; acc += d0;
c_m_n_host_result(m, n) = static_cast<CDataType>(acc); c_m_n_host_result(m, n) = static_cast<CDataType>(acc);
} }
for(int m = 0; m < M; ++m) for(int m = 0; m < M; ++m)
{ {
auto d0_acc = d0_reduce_op.GetIdentityValue<ReduceAccDataType>(); auto reduce0_acc = reduce0_op.GetIdentityValue<ReduceAccDataType>();
auto d1_acc = d1_reduce_op.GetIdentityValue<ReduceAccDataType>(); auto reduce1_acc = reduce1_op.GetIdentityValue<ReduceAccDataType>();
for(int n = 0; n < N; ++n) for(int n = 0; n < N; ++n)
{ {
ReduceAccDataType c_val = ReduceAccDataType d0_val =
ck::type_convert<ReduceAccDataType>(c_m_n_host_result(m, n)); ck::type_convert<ReduceAccDataType>(c_m_n_host_result(m, n));
ReduceAccDataType d0_val;
ReduceAccDataType d1_val; ReduceAccDataType d1_val;
dxs_in_element_op(ck::Number<0>{})(d0_val, c_val); square(d1_val, d0_val);
dxs_in_element_op(ck::Number<1>{})(d1_val, c_val); reduce0_op(reduce0_acc, d0_val);
d0_reduce_op(d0_acc, d0_val); reduce1_op(reduce1_acc, d1_val);
d1_reduce_op(d1_acc, d1_val);
} }
dxs_out_element_op(ck::Number<0>{})(d0_acc, d0_acc); div(reduce0_acc, reduce0_acc);
dxs_out_element_op(ck::Number<1>{})(d1_acc, d1_acc); div(reduce1_acc, reduce1_acc);
d0_m_host_result(m) = ck::type_convert<DDataType>(d0_acc); reduce0_m_host_result(m) = ck::type_convert<ReduceDataType>(reduce0_acc);
d1_m_host_result(m) = ck::type_convert<DDataType>(d1_acc); reduce1_m_host_result(m) = ck::type_convert<ReduceDataType>(reduce1_acc);
} }
} }
DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace()); DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace());
DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace()); DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace());
DeviceMem c_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpace()); DeviceMem c_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpace());
DeviceMem bias_device_buf(sizeof(C0DataType) * bias_n.mDesc.GetElementSpace()); DeviceMem bias_device_buf(sizeof(BiasDataType) * bias_n.mDesc.GetElementSpace());
DeviceMem c1_device_buf(sizeof(C1DataType) * c1_m_n.mDesc.GetElementSpace()); DeviceMem d0_device_buf(sizeof(D0DataType) * d0_m_n.mDesc.GetElementSpace());
DeviceMem d0_device_buf(sizeof(DDataType) * d0_m_device_result.mDesc.GetElementSpace()); DeviceMem reduce0_device_buf(sizeof(ReduceDataType) *
DeviceMem d1_device_buf(sizeof(DDataType) * d1_m_device_result.mDesc.GetElementSpace()); reduce0_m_device_result.mDesc.GetElementSpace());
DeviceMem reduce1_device_buf(sizeof(ReduceDataType) *
reduce1_m_device_result.mDesc.GetElementSpace());
auto dxs_global = ck::make_tuple(static_cast<DDataType*>(d0_device_buf.GetDeviceBuffer()), std::array<void*, 2> p_reduces = {reduce0_device_buf.GetDeviceBuffer(),
static_cast<DDataType*>(d1_device_buf.GetDeviceBuffer())); reduce1_device_buf.GetDeviceBuffer()};
a_device_buf.ToDevice(a_m_k.mData.data()); a_device_buf.ToDevice(a_m_k.mData.data());
b_device_buf.ToDevice(b_k_n.mData.data()); b_device_buf.ToDevice(b_k_n.mData.data());
bias_device_buf.ToDevice(bias_n.mData.data()); bias_device_buf.ToDevice(bias_n.mData.data());
c1_device_buf.ToDevice(c1_m_n.mData.data()); d0_device_buf.ToDevice(d0_m_n.mData.data());
// add device GEMM instances // add device GEMM instances
std::vector<ck::tensor_operation::device::device_gemm_instance::DeviceGemmBiasAddReduceNoOpPtr> std::vector<ck::tensor_operation::device::instance::DeviceGemmBiasAddReduceNoOpPtr> gemm_ptrs;
gemm_ptrs;
if constexpr(is_same<ADataType, half_t>::value && is_same<BDataType, half_t>::value && if constexpr(is_same<ADataType, half_t>::value && is_same<BDataType, half_t>::value &&
is_same<CDataType, half_t>::value) is_same<CDataType, half_t>::value)
...@@ -248,32 +245,32 @@ void profile_gemm_bias_add_reduce_impl(int do_verification, ...@@ -248,32 +245,32 @@ void profile_gemm_bias_add_reduce_impl(int do_verification,
is_same<BLayout, tensor_layout::gemm::RowMajor>::value && is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value) is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{ {
ck::tensor_operation::device::device_gemm_instance:: ck::tensor_operation::device::instance::
add_device_gemm_bias_add_reduce_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_mk_kn_mn_instances( add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_mk_kn_mn_instances(
gemm_ptrs); gemm_ptrs);
} }
else if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value && else if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value &&
is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value && is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value) is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{ {
ck::tensor_operation::device::device_gemm_instance:: ck::tensor_operation::device::instance::
add_device_gemm_bias_add_reduce_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_mk_nk_mn_instances( add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_mk_nk_mn_instances(
gemm_ptrs); gemm_ptrs);
} }
else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value && else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
is_same<BLayout, tensor_layout::gemm::RowMajor>::value && is_same<BLayout, tensor_layout::gemm::RowMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value) is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{ {
ck::tensor_operation::device::device_gemm_instance:: ck::tensor_operation::device::instance::
add_device_gemm_bias_add_reduce_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_kn_mn_instances( add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_kn_mn_instances(
gemm_ptrs); gemm_ptrs);
} }
else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value && else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value &&
is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value && is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value &&
is_same<CLayout, tensor_layout::gemm::RowMajor>::value) is_same<CLayout, tensor_layout::gemm::RowMajor>::value)
{ {
ck::tensor_operation::device::device_gemm_instance:: ck::tensor_operation::device::instance::
add_device_gemm_bias_add_reduce_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_nk_mn_instances( add_device_gemm_bias_add_mean_squaremean_xdl_cshuffle_f16_f16_f16_f16_f16_f32_f32_km_nk_mn_instances(
gemm_ptrs); gemm_ptrs);
} }
} }
...@@ -291,34 +288,31 @@ void profile_gemm_bias_add_reduce_impl(int do_verification, ...@@ -291,34 +288,31 @@ void profile_gemm_bias_add_reduce_impl(int do_verification,
// profile device GEMM instances // profile device GEMM instances
for(auto& gemm_ptr : gemm_ptrs) for(auto& gemm_ptr : gemm_ptrs)
{ {
auto argument_ptr = gemm_ptr->MakeArgumentPointer( auto argument_ptr = gemm_ptr->MakeArgumentPointer(a_device_buf.GetDeviceBuffer(),
static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()), b_device_buf.GetDeviceBuffer(),
static_cast<BDataType*>(b_device_buf.GetDeviceBuffer()), bias_device_buf.GetDeviceBuffer(),
static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()), {d0_device_buf.GetDeviceBuffer()},
static_cast<C0DataType*>(bias_device_buf.GetDeviceBuffer()), c_device_buf.GetDeviceBuffer(),
static_cast<C1DataType*>(c1_device_buf.GetDeviceBuffer()), p_reduces,
&dxs_global, M,
M, N,
N, K,
K, StrideA,
StrideA, StrideB,
StrideB, StrideC,
StrideC, {StrideD0},
StrideC1, gemm_element_ops,
a_element_op, {&d0_element_op},
b_element_op, reduce_in_element_ops,
c_element_op, reduce_out_element_ops);
c1_element_op,
dxs_in_element_op,
dxs_out_element_op);
auto invoker_ptr = gemm_ptr->MakeInvokerPointer(); auto invoker_ptr = gemm_ptr->MakeInvokerPointer();
if(gemm_ptr->IsSupportedArgument(argument_ptr.get())) if(gemm_ptr->IsSupportedArgument(argument_ptr.get()))
{ {
// init D0, D1 to 0 // init reduce0, reduce1 to 0
d0_device_buf.SetZero(); reduce0_device_buf.SetZero();
d1_device_buf.SetZero(); reduce1_device_buf.SetZero();
float ave_time = float ave_time =
invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel});
...@@ -328,9 +322,9 @@ void profile_gemm_bias_add_reduce_impl(int do_verification, ...@@ -328,9 +322,9 @@ void profile_gemm_bias_add_reduce_impl(int do_verification,
std::size_t flop = std::size_t(2) * M * N * K + std::size_t(2) * M * N; std::size_t flop = std::size_t(2) * M * N * K + std::size_t(2) * M * N;
std::size_t num_byte = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + std::size_t num_byte = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N +
sizeof(CDataType) * M * N + sizeof(C0DataType) * M * N + sizeof(CDataType) * M * N + sizeof(BiasDataType) * M * N +
sizeof(C1DataType) * M * N + sizeof(DDataType) * M + sizeof(D0DataType) * M * N + sizeof(ReduceDataType) * M +
sizeof(DDataType) * M; sizeof(ReduceDataType) * M;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time; float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
...@@ -350,12 +344,12 @@ void profile_gemm_bias_add_reduce_impl(int do_verification, ...@@ -350,12 +344,12 @@ void profile_gemm_bias_add_reduce_impl(int do_verification,
if(do_verification) if(do_verification)
{ {
c_device_buf.FromDevice(c_m_n_device_result.mData.data()); c_device_buf.FromDevice(c_m_n_device_result.mData.data());
d0_device_buf.FromDevice(d0_m_device_result.mData.data()); reduce0_device_buf.FromDevice(reduce0_m_device_result.mData.data());
d1_device_buf.FromDevice(d1_m_device_result.mData.data()); reduce1_device_buf.FromDevice(reduce1_m_device_result.mData.data());
ck::utils::check_err(c_m_n_device_result.mData, c_m_n_host_result.mData); ck::utils::check_err(c_m_n_device_result.mData, c_m_n_host_result.mData);
ck::utils::check_err(d0_m_device_result.mData, d0_m_host_result.mData); ck::utils::check_err(reduce0_m_device_result.mData, reduce0_m_host_result.mData);
ck::utils::check_err(d1_m_device_result.mData, d1_m_host_result.mData); ck::utils::check_err(reduce1_m_device_result.mData, reduce1_m_host_result.mData);
if(do_log) if(do_log)
{ {
...@@ -365,13 +359,17 @@ void profile_gemm_bias_add_reduce_impl(int do_verification, ...@@ -365,13 +359,17 @@ void profile_gemm_bias_add_reduce_impl(int do_verification,
<< std::endl; << std::endl;
LogRangeAsType<float>(std::cout << "c_device: ", c_m_n_device_result.mData, ",") LogRangeAsType<float>(std::cout << "c_device: ", c_m_n_device_result.mData, ",")
<< std::endl; << std::endl;
LogRangeAsType<float>(std::cout << "d0_host: ", d0_m_host_result.mData, ",") LogRangeAsType<float>(
std::cout << "d0_host: ", reduce0_m_host_result.mData, ",")
<< std::endl; << std::endl;
LogRangeAsType<float>(std::cout << "d0_device: ", d0_m_device_result.mData, ",") LogRangeAsType<float>(
std::cout << "d0_device: ", reduce0_m_device_result.mData, ",")
<< std::endl; << std::endl;
LogRangeAsType<float>(std::cout << "d1_host: ", d1_m_host_result.mData, ",") LogRangeAsType<float>(
std::cout << "d1_host: ", reduce1_m_host_result.mData, ",")
<< std::endl; << std::endl;
LogRangeAsType<float>(std::cout << "d1_device: ", d1_m_device_result.mData, ",") LogRangeAsType<float>(
std::cout << "d1_device: ", reduce1_m_device_result.mData, ",")
<< std::endl; << std::endl;
} }
} }
......