Commit 3d005816 authored by Chao Liu

update example

parent 9551101e
+add_example_executable(example_gemm_add_add_fastgelu_xdl_fp16 gemm_add_add_fastgelu_xdl_fp16.cpp)
-# Instructions for ```example_gemm_bias_add_fastgelu_xdl_fp16```
+# Instructions for ```example_gemm_add_add_fastgelu_xdl_fp16```

-## Run ```example_gemm_bias_add_fastgelu_xdl_fp16```
+## Run ```example_gemm_add_add_fastgelu_xdl_fp16```
```bash
#arg1: verification (0=no, 1=yes)
#arg2: initialization (0=no init, 1=integer value, 2=decimal value)
-#arg3: arg3: time kernel (0=no, 1=yes)
+#arg3: time kernel (0=no, 1=yes)
+#arg4 to 11: M (256x), N(128x), K(32x), StrideA, StrideB, StrideD0, StrideD1, StrideE
-./bin/example_gemm_bias_add_fastgelu_xdl_fp16 1 1 1
+./bin/example_gemm_add_add_fastgelu_xdl_fp16 1 1 1
```
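To run with an explicit problem size, pass all eleven arguments. The sizes below are only an illustration (per the comment above, M, N, and K must be multiples of 256, 128, and 32; StrideD0 = 0 makes D0 a bias row broadcast over M):
```bash
# arg4..11: M N K StrideA StrideB StrideD0 StrideD1 StrideE (hypothetical sizes)
./bin/example_gemm_add_add_fastgelu_xdl_fp16 1 1 1 3840 4096 4096 4096 4096 0 4096 4096
```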
Result (MI100 @ 1087MHz, 133.5 TFLOPS peak FP16)
...
@@ -36,9 +36,11 @@ using D1DataType = F16;
 using DsDataType = ck::Tuple<D0DataType, D1DataType>;
 using EDataType  = F16;

 using ALayout = Row;
 using BLayout = Col;
-using ELayout = Row;
+using D0Layout = Row;
+using D1Layout = Row;
+using ELayout  = Row;

 using AElementOp = PassThrough;
 using BElementOp = PassThrough;
@@ -68,6 +70,7 @@ int main(int argc, char* argv[])
     ck::index_t StrideA  = 4096;
     ck::index_t StrideB  = 4096;
+    ck::index_t StrideD0 = 0;
     ck::index_t StrideD1 = 4096;
     ck::index_t StrideE  = 4096;
@@ -81,7 +84,7 @@ int main(int argc, char* argv[])
         init_method = std::stoi(argv[2]);
         time_kernel = std::stoi(argv[3]);
     }
-    else if(argc == 11)
+    else if(argc == 12)
     {
         do_verification = std::stoi(argv[1]);
         init_method     = std::stoi(argv[2]);
@@ -93,15 +96,17 @@ int main(int argc, char* argv[])
         StrideA = std::stoi(argv[7]);
         StrideB = std::stoi(argv[8]);
-        StrideD1 = std::stoi(argv[9]);
-        StrideE  = std::stoi(argv[10]);
+        StrideD0 = std::stoi(argv[9]);
+        StrideD1 = std::stoi(argv[10]);
+        StrideE  = std::stoi(argv[11]);
     }
     else
     {
         printf("arg1: verification (0=no, 1=yes)\n");
         printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
         printf("arg3: time kernel (0=no, 1=yes)\n");
-        printf("arg4 to 10: M (256x), N(128x), K(32x), StrideA, StrideB, StrideD1, StrideE\n");
+        printf("arg4 to 11: M (256x), N(128x), K(32x), StrideA, StrideB, StrideD0, StrideD1, "
+               "StrideE\n");
         exit(0);
     }
@@ -121,8 +126,8 @@ int main(int argc, char* argv[])
     Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
     Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
-    Tensor<EDataType> d0_m_n(f_host_tensor_descriptor(M, N, 0, ELayout{}));
-    Tensor<EDataType> d1_m_n(f_host_tensor_descriptor(M, N, StrideE, ELayout{}));
+    Tensor<D0DataType> d0_m_n(f_host_tensor_descriptor(M, N, StrideD0, D0Layout{}));
+    Tensor<D1DataType> d1_m_n(f_host_tensor_descriptor(M, N, StrideD1, D1Layout{}));
     Tensor<EDataType> e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{}));
     Tensor<EDataType> e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{}));
@@ -138,14 +143,14 @@ int main(int argc, char* argv[])
     case 1:
         a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
         b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
-        d0_m_n.GenerateTensorValue(GeneratorTensor_2<EDataType>{-5, 5});
-        d1_m_n.GenerateTensorValue(GeneratorTensor_2<EDataType>{-5, 5});
+        d0_m_n.GenerateTensorValue(GeneratorTensor_2<D0DataType>{-5, 5});
+        d1_m_n.GenerateTensorValue(GeneratorTensor_2<D1DataType>{-5, 5});
         break;
     default:
         a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
         b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
-        d0_m_n.GenerateTensorValue(GeneratorTensor_3<EDataType>{0.0, 1.0});
-        d1_m_n.GenerateTensorValue(GeneratorTensor_3<EDataType>{0.0, 1.0});
+        d0_m_n.GenerateTensorValue(GeneratorTensor_3<D0DataType>{0.0, 1.0});
+        d1_m_n.GenerateTensorValue(GeneratorTensor_3<D1DataType>{0.0, 1.0});
     }

     DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace());
@@ -177,7 +182,7 @@ int main(int argc, char* argv[])
         K,
         StrideA,
         StrideB,
-        std::array<ck::index_t, 2>{0, StrideD1},
+        std::array<ck::index_t, 2>{StrideD0, StrideD1},
         StrideE,
         a_element_op,
         b_element_op,
@@ -204,9 +209,8 @@ int main(int argc, char* argv[])
     if(do_verification)
     {
-        e_m_n_device_buf.FromDevice(e_m_n_device_result.mData.data());
-
-        Tensor<AccDataType> c_m_n(f_host_tensor_descriptor(M, N, StrideE, ELayout{}));
+        Tensor<AccDataType> c_m_n(HostTensorDescriptor(
+            std::vector<std::size_t>{static_cast<std::size_t>(M), static_cast<std::size_t>(N)}));

         using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
                                                                                 BDataType,
@@ -232,6 +236,8 @@ int main(int argc, char* argv[])
         }
     }

+    e_m_n_device_buf.FromDevice(e_m_n_device_result.mData.data());
+
     return ck::utils::check_err(e_m_n_device_result.mData, e_m_n_host_result.mData) ? 0 : 1;
 }
...
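The fused epilogue this example computes is E = FastGelu(A*B + D0 + D1), applied elementwise by `cde_element_op` in the verification loop above. A minimal host-side sketch of that semantics, assuming the common tanh-based FastGelu approximation (the exact polynomial inside CK's AddAddFastGelu may differ slightly):

```cpp
#include <cmath>

// Tanh-based FastGelu approximation (assumed form; CK's internal
// polynomial may differ): 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 x^3)))
inline float fast_gelu(float x)
{
    return 0.5f * x * (1.f + std::tanh(0.7978845608f * (x + 0.044715f * x * x * x)));
}

// Per-element fused epilogue: e = FastGelu(c + d0 + d1), where c = (A * B)(m, n).
inline void add_add_fast_gelu(float& e, float c, float d0, float d1)
{
    e = fast_gelu(c + d0 + d1);
}
```

Note that the example's default `StrideD0 = 0` makes every row of D0 read the same memory, so D0 acts as a length-N bias row broadcast over M; that is why the old code passed a literal 0 as the first element of the stride array.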
-add_example_executable(example_gemm_bias_add_fastgelu_xdl_fp16 gemm_bias_add_fastgelu_xdl_fp16.cpp)
@@ -39,7 +39,7 @@ endfunction(add_example_executable_no_testing EXAMPLE_NAME)
 add_subdirectory(01_gemm)
 add_subdirectory(02_gemm_alpha_beta)
 add_subdirectory(03_gemm_bias_relu)
-add_subdirectory(04_gemm_bias_add_fastgelu)
+add_subdirectory(04_gemm_add_add_fastgelu)
 add_subdirectory(06_conv2d_fwd_bias_relu)
 add_subdirectory(07_conv2d_fwd_bias_relu_add)
 add_subdirectory(09_convnd_fwd)
...
@@ -25,20 +25,20 @@ include_directories(BEFORE
 set(PROFILER_SOURCE
     src/profiler.cpp
     src/profile_gemm.cpp
-    src/profile_gemm_bias_2d.cpp
-    src/profile_gemm_bias_relu.cpp
-    src/profile_gemm_bias_relu_add.cpp
-    src/profile_gemm_reduce.cpp
-    src/profile_batched_gemm.cpp
-    src/profile_conv_fwd_bias_relu.cpp
-    src/profile_conv_fwd_bias_relu_add.cpp
-    src/profile_conv_fwd_bias_relu_atomic_add.cpp
-    src/profile_convnd_fwd.cpp
-    src/profile_convnd_bwd_data.cpp
-    src/profile_reduce.cpp
-    src/profile_grouped_gemm.cpp
-    src/profile_conv_bwd_weight.cpp
-    src/profile_batched_gemm_reduce.cpp
+    # src/profile_gemm_bias_2d.cpp
+    # src/profile_gemm_bias_relu.cpp
+    # src/profile_gemm_bias_relu_add.cpp
+    # src/profile_gemm_reduce.cpp
+    # src/profile_batched_gemm.cpp
+    # src/profile_conv_fwd_bias_relu.cpp
+    # src/profile_conv_fwd_bias_relu_add.cpp
+    # src/profile_conv_fwd_bias_relu_atomic_add.cpp
+    # src/profile_convnd_fwd.cpp
+    # src/profile_convnd_bwd_data.cpp
+    # src/profile_reduce.cpp
+    # src/profile_grouped_gemm.cpp
+    # src/profile_conv_bwd_weight.cpp
+    # src/profile_batched_gemm_reduce.cpp
     src/profile_gemm_add_add_fastgelu.cpp
 )
@@ -46,21 +46,21 @@ add_executable(ckProfiler ${PROFILER_SOURCE})
 target_link_libraries(ckProfiler PRIVATE host_tensor)
 target_link_libraries(ckProfiler PRIVATE conv_util)
-target_link_libraries(ckProfiler PRIVATE device_gemm_reduce_instance)
+#target_link_libraries(ckProfiler PRIVATE device_gemm_reduce_instance)
 target_link_libraries(ckProfiler PRIVATE device_gemm_instance)
-target_link_libraries(ckProfiler PRIVATE device_gemm_bias2d_instance)
-target_link_libraries(ckProfiler PRIVATE device_gemm_bias_relu_instance)
-target_link_libraries(ckProfiler PRIVATE device_gemm_bias_relu_add_instance)
-target_link_libraries(ckProfiler PRIVATE device_batched_gemm_instance)
-target_link_libraries(ckProfiler PRIVATE device_conv1d_fwd_instance)
-target_link_libraries(ckProfiler PRIVATE device_conv2d_fwd_instance)
-target_link_libraries(ckProfiler PRIVATE device_conv3d_fwd_instance)
-target_link_libraries(ckProfiler PRIVATE device_conv2d_fwd_bias_relu_instance)
-target_link_libraries(ckProfiler PRIVATE device_conv2d_fwd_bias_relu_add_instance)
-target_link_libraries(ckProfiler PRIVATE device_conv2d_fwd_bias_relu_atomic_add_instance)
-target_link_libraries(ckProfiler PRIVATE device_convnd_bwd_data_instance)
-target_link_libraries(ckProfiler PRIVATE device_reduce_instance)
-target_link_libraries(ckProfiler PRIVATE device_grouped_gemm_instance)
-target_link_libraries(ckProfiler PRIVATE device_conv2d_bwd_weight_instance)
-target_link_libraries(ckProfiler PRIVATE device_batched_gemm_reduce_instance)
+#target_link_libraries(ckProfiler PRIVATE device_gemm_bias2d_instance)
+#target_link_libraries(ckProfiler PRIVATE device_gemm_bias_relu_instance)
+#target_link_libraries(ckProfiler PRIVATE device_gemm_bias_relu_add_instance)
+#target_link_libraries(ckProfiler PRIVATE device_batched_gemm_instance)
+#target_link_libraries(ckProfiler PRIVATE device_conv1d_fwd_instance)
+#target_link_libraries(ckProfiler PRIVATE device_conv2d_fwd_instance)
+#target_link_libraries(ckProfiler PRIVATE device_conv3d_fwd_instance)
+#target_link_libraries(ckProfiler PRIVATE device_conv2d_fwd_bias_relu_instance)
+#target_link_libraries(ckProfiler PRIVATE device_conv2d_fwd_bias_relu_add_instance)
+#target_link_libraries(ckProfiler PRIVATE device_conv2d_fwd_bias_relu_atomic_add_instance)
+#target_link_libraries(ckProfiler PRIVATE device_convnd_bwd_data_instance)
+#target_link_libraries(ckProfiler PRIVATE device_reduce_instance)
+#target_link_libraries(ckProfiler PRIVATE device_grouped_gemm_instance)
+#target_link_libraries(ckProfiler PRIVATE device_conv2d_bwd_weight_instance)
+#target_link_libraries(ckProfiler PRIVATE device_batched_gemm_reduce_instance)
 target_link_libraries(ckProfiler PRIVATE device_gemm_add_add_fastgelu_instance)
@@ -25,13 +25,13 @@ using DeviceGemmAddAddFastGeluPtr = ck::tensor_operation::device::DeviceGemmMult
     ck::tensor_operation::element_wise::PassThrough,
     ck::tensor_operation::element_wise::FastGelu>;

-void add_device_gemm_gelu_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances(
+void add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances(
     std::vector<DeviceGemmAddAddFastGeluPtr>&);
-void add_device_gemm_gelu_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances(
+void add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances(
     std::vector<DeviceGemmAddAddFastGeluPtr>&);
-void add_device_gemm_gelu_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances(
+void add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances(
     std::vector<DeviceGemmAddAddFastGeluPtr>&);
-void add_device_gemm_gelu_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances(
+void add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances(
     std::vector<DeviceGemmAddAddFastGeluPtr>&);

 } // namespace device_gemm_instance
@@ -44,20 +44,26 @@ namespace profiler {
 template <typename ADataType,
           typename BDataType,
-          typename CDataType,
+          typename D0DataType,
+          typename D1DataType,
+          typename EDataType,
           typename ALayout,
           typename BLayout,
-          typename CLayout>
-int profile_gemm_gelu_impl(int do_verification,
-                           int init_method,
-                           bool do_log,
-                           bool time_kernel,
-                           int M,
-                           int N,
-                           int K,
-                           int StrideA,
-                           int StrideB,
-                           int StrideC)
+          typename D0Layout,
+          typename D1Layout,
+          typename ELayout>
+int profile_gemm_add_add_fastgelu_impl(int do_verification,
+                                       int init_method,
+                                       bool do_log,
+                                       bool time_kernel,
+                                       int M,
+                                       int N,
+                                       int K,
+                                       int StrideA,
+                                       int StrideB,
+                                       int StrideD0,
+                                       int StrideD1,
+                                       int StrideE)
 {
     auto f_host_tensor_descriptor =
         [](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
@@ -75,65 +81,75 @@ int profile_gemm_gelu_impl(int do_verification,
     Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
     Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
-    Tensor<CDataType> c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
-    Tensor<CDataType> c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
+    Tensor<D0DataType> d0_m_n(f_host_tensor_descriptor(M, N, StrideD0, D0Layout{}));
+    Tensor<D1DataType> d1_m_n(f_host_tensor_descriptor(M, N, StrideD1, D1Layout{}));
+    Tensor<EDataType> e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{}));
+    Tensor<EDataType> e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{}));

     std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
     std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
-    std::cout << "c_m_n: " << c_m_n_device_result.mDesc << std::endl;
+    std::cout << "d0_m_n: " << d0_m_n.mDesc << std::endl;
+    std::cout << "d1_m_n: " << d1_m_n.mDesc << std::endl;
+    std::cout << "e_m_n: " << e_m_n_device_result.mDesc << std::endl;

-    std::size_t num_thread = 1;
-
     switch(init_method)
     {
     case 0: break;
     case 1:
-        a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5}, num_thread);
-        b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5}, num_thread);
+        a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
+        b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
+        d0_m_n.GenerateTensorValue(GeneratorTensor_2<D0DataType>{-5, 5});
+        d1_m_n.GenerateTensorValue(GeneratorTensor_2<D1DataType>{-5, 5});
         break;
     default:
-        a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0}, num_thread);
-        b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5}, num_thread);
+        a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
+        b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
+        d0_m_n.GenerateTensorValue(GeneratorTensor_3<D0DataType>{0.0, 1.0});
+        d1_m_n.GenerateTensorValue(GeneratorTensor_3<D1DataType>{0.0, 1.0});
     }

-    using AElementOp = ck::tensor_operation::element_wise::PassThrough;
-    using BElementOp = ck::tensor_operation::element_wise::PassThrough;
-    using CElementOp = ck::tensor_operation::element_wise::FastGelu;
+    using PassThrough    = ck::tensor_operation::element_wise::PassThrough;
+    using AddAddFastGelu = ck::tensor_operation::element_wise::AddAddFastGelu;

-    const auto a_element_op = AElementOp{};
-    const auto b_element_op = BElementOp{};
-    const auto c_element_op = CElementOp{};
+    using AElementOp   = PassThrough;
+    using BElementOp   = PassThrough;
+    using CDEElementOp = AddAddFastGelu;
+
+    const auto a_element_op   = AElementOp{};
+    const auto b_element_op   = BElementOp{};
+    const auto cde_element_op = CDEElementOp{};

     // add device GEMM instances
     std::vector<ck::tensor_operation::device::device_gemm_instance::DeviceGemmAddAddFastGeluPtr>
         device_op_ptrs;

     if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, half_t> &&
-                 is_same_v<CDataType, half_t>)
+                 is_same_v<EDataType, half_t>)
     {
         if constexpr(is_same_v<ALayout, tensor_layout::gemm::RowMajor> &&
                      is_same_v<BLayout, tensor_layout::gemm::RowMajor> &&
-                     is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
+                     is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
         {
             ck::tensor_operation::device::device_gemm_instance::
-                add_device_gemm_gelu_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances(device_op_ptrs);
+                add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances(
+                    device_op_ptrs);
         }
         else if constexpr(is_same_v<ALayout, tensor_layout::gemm::RowMajor> &&
                           is_same_v<BLayout, tensor_layout::gemm::ColumnMajor> &&
-                          is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
+                          is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
         {
             ck::tensor_operation::device::device_gemm_instance::
-                add_device_gemm_gelu_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances(device_op_ptrs);
+                add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances(
+                    device_op_ptrs);
         }
         else if constexpr(is_same_v<ALayout, tensor_layout::gemm::ColumnMajor> &&
                           is_same_v<BLayout, tensor_layout::gemm::RowMajor> &&
-                          is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
+                          is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
         {
             ck::tensor_operation::device::device_gemm_instance::
-                add_device_gemm_gelu_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances(device_op_ptrs);
+                add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances(
+                    device_op_ptrs);
         }
         else if constexpr(is_same_v<ALayout, tensor_layout::gemm::ColumnMajor> &&
                           is_same_v<BLayout, tensor_layout::gemm::ColumnMajor> &&
-                          is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
+                          is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
         {
             ck::tensor_operation::device::device_gemm_instance::
-                add_device_gemm_gelu_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances(device_op_ptrs);
+                add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances(
+                    device_op_ptrs);
@@ -145,23 +161,44 @@ int profile_gemm_gelu_impl(int do_verification,
     // run reference
     if(do_verification)
     {
-        using ReferenceOpInstance = ck::tensor_operation::host::
-            ReferenceGemm<ADataType, BDataType, CDataType, AElementOp, BElementOp, CElementOp>;
+        Tensor<AccDataType> c_m_n(HostTensorDescriptor(
+            std::vector<std::size_t>{static_cast<std::size_t>(M), static_cast<std::size_t>(N)}));
+
+        using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
+                                                                                BDataType,
+                                                                                AccDataType,
+                                                                                AccDataType,
+                                                                                AElementOp,
+                                                                                BElementOp,
+                                                                                PassThrough>;

-        auto ref_op       = ReferenceOpInstance{};
-        auto ref_invoker  = ref_op.MakeInvoker();
-        auto ref_argument = ref_op.MakeArgument(
-            a_m_k, b_k_n, c_m_n_host_result, a_element_op, b_element_op, c_element_op);
+        auto ref_gemm    = ReferenceGemmInstance{};
+        auto ref_invoker = ref_gemm.MakeInvoker();
+
+        auto ref_argument =
+            ref_gemm.MakeArgument(a_m_k, b_k_n, c_m_n, a_element_op, b_element_op, PassThrough{});

         ref_invoker.Run(ref_argument);
+
+        for(int m = 0; m < M; ++m)
+        {
+            for(int n = 0; n < N; ++n)
+            {
+                cde_element_op(e_m_n_host_result(m, n), c_m_n(m, n), d0_m_n(m, n), d1_m_n(m, n));
+            }
+        }
     }
     DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace());
     DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpace());
-    DeviceMem c_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpace());
+    DeviceMem d0_m_n_device_buf(sizeof(D0DataType) * d0_m_n.mDesc.GetElementSpace());
+    DeviceMem d1_m_n_device_buf(sizeof(D1DataType) * d1_m_n.mDesc.GetElementSpace());
+    DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpace());

     a_device_buf.ToDevice(a_m_k.mData.data());
     b_device_buf.ToDevice(b_k_n.mData.data());
+    d0_m_n_device_buf.ToDevice(d0_m_n.mData.data());
+    d1_m_n_device_buf.ToDevice(d1_m_n.mData.data());

     std::string best_device_op_name;
     float best_ave_time = 0;
@@ -174,18 +211,21 @@ int profile_gemm_gelu_impl(int do_verification,
     for(auto& device_op_ptr : device_op_ptrs)
     {
         auto argument_ptr = device_op_ptr->MakeArgumentPointer(
-            static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()),
-            static_cast<BDataType*>(b_device_buf.GetDeviceBuffer()),
-            static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()),
+            a_device_buf.GetDeviceBuffer(),
+            b_device_buf.GetDeviceBuffer(),
+            std::array<const void*, 2>{d0_m_n_device_buf.GetDeviceBuffer(),
+                                       d1_m_n_device_buf.GetDeviceBuffer()},
+            static_cast<EDataType*>(e_device_buf.GetDeviceBuffer()),
             M,
             N,
             K,
             StrideA,
             StrideB,
-            StrideC,
+            std::array<ck::index_t, 2>{StrideD0, StrideD1},
+            StrideE,
             a_element_op,
             b_element_op,
-            c_element_op);
+            cde_element_op);

         auto invoker_ptr = device_op_ptr->MakeInvokerPointer();
@@ -193,8 +233,8 @@ int profile_gemm_gelu_impl(int do_verification,
         if(device_op_ptr->IsSupportedArgument(argument_ptr.get()))
         {
-            // re-init C to zero before profiling a kernel
-            c_device_buf.SetZero();
+            // re-init E to zero before profiling a kernel
+            e_device_buf.SetZero();

             float ave_time =
                 invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel});
@@ -202,7 +242,7 @@ int profile_gemm_gelu_impl(int do_verification,
             std::size_t flop = std::size_t(2) * M * N * K;
             std::size_t num_btype =
-                sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N;
+                sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(EDataType) * M * N;

             float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
@@ -221,20 +261,10 @@ int profile_gemm_gelu_impl(int do_verification,
             if(do_verification)
             {
-                c_device_buf.FromDevice(c_m_n_device_result.mData.data());
+                e_device_buf.FromDevice(e_m_n_device_result.mData.data());

                 pass = pass &&
-                       ck::utils::check_err(c_m_n_device_result.mData, c_m_n_host_result.mData);
-
-                if(do_log)
-                {
-                    LogRangeAsType<float>(std::cout << "a : ", a_m_k.mData, ",") << std::endl;
-                    LogRangeAsType<float>(std::cout << "b: ", b_k_n.mData, ",") << std::endl;
-                    LogRangeAsType<float>(std::cout << "c_host : ", c_m_n_host_result.mData, ",")
-                        << std::endl;
-                    LogRangeAsType<float>(std::cout << "c_device: ", c_m_n_device_result.mData, ",")
-                        << std::endl;
-                }
+                       ck::utils::check_err(e_m_n_device_result.mData, e_m_n_host_result.mData);
             }
         }
         else
...
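The throughput bookkeeping in the timing loop above is `flop = 2*M*N*K` and `tflops = flop / 1e9 / ave_time`, with `ave_time` in milliseconds; `num_btype` counts only the A, B, and E traffic (D0/D1 reads are ignored). A worked example with hypothetical numbers:

```cpp
#include <cstddef>
#include <iostream>

int main()
{
    // Hypothetical problem size and measured kernel time.
    const std::size_t M = 3840, N = 4096, K = 4096;
    const float ave_time = 1.0f; // milliseconds

    const std::size_t flop = std::size_t(2) * M * N * K; // 128,849,018,880 flop
    const float tflops     = static_cast<float>(flop) / 1.E9f / ave_time;

    std::cout << tflops << " TFLOPS\n"; // ~128.8, near the MI100 fp16 peak of 133.5
    return 0;
}
```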
@@ -8,24 +8,24 @@
 int profile_gemm_add_add_fastgelu(int argc, char* argv[])
 {
-    enum struct GemmMatrixLayout
+    enum struct MatrixLayout
     {
-        MK_KN_MN, // 0
-        MK_NK_MN, // 1
-        KM_KN_MN, // 2
-        KM_NK_MN, // 3
-        MK_KN_NM, // 4
-        MK_NK_NM, // 5
-        KM_KN_NM, // 6
-        KM_NK_NM, // 7
+        MK_KN_MN_MN_MN, // 0
+        MK_NK_MN_MN_MN, // 1
+        KM_KN_MN_MN_MN, // 2
+        KM_NK_MN_MN_MN, // 3
+        MK_KN_NM_MN_MN, // 4
+        MK_NK_NM_MN_MN, // 5
+        KM_KN_NM_MN_MN, // 6
+        KM_NK_NM_MN_MN, // 7
     };

-    enum struct GemmDataType
+    enum struct MatrixDataType
     {
-        F32_F32_F32,    // 0
-        F16_F16_F16,    // 1
-        BF16_BF16_BF16, // 2
-        INT8_INT8_INT8, // 3
+        F32_F32_F32_F32_F32,      // 0
+        F16_F16_F16_F16_F16,      // 1
+        BF16_BF16_BF16_BF16_BF16, // 2
+        INT8_INT8_INT8_INT8_INT8, // 3
     };

     if(argc != 16)
@@ -41,13 +41,13 @@ int profile_gemm_add_add_fastgelu(int argc, char* argv[])
         printf("arg5: initialization (0: no init; 1: integer value; 2: decimal value)\n");
         printf("arg6: print tensor value (0: no; 1: yes)\n");
         printf("arg7: time kernel (0=no, 1=yes)\n");
-        printf("arg8 to 13: M, N, K, StrideA, StrideB, StrideC, StrideD0, StrideD1\n");
+        printf("arg8 to 15: M, N, K, StrideA, StrideB, StrideD0, StrideD1, StrideE\n");
         // clang-format on
         exit(1);
     }

-    const auto data_type = static_cast<GemmDataType>(std::stoi(argv[2]));
-    const auto layout    = static_cast<GemmMatrixLayout>(std::stoi(argv[3]));
+    const auto data_type = static_cast<MatrixDataType>(std::stoi(argv[2]));
+    const auto layout    = static_cast<MatrixLayout>(std::stoi(argv[3]));

     const bool do_verification = std::stoi(argv[4]);
     const int init_method      = std::stoi(argv[5]);
     const bool do_log          = std::stoi(argv[6]);
@@ -59,57 +59,85 @@ int profile_gemm_add_add_fastgelu(int argc, char* argv[])
     const int StrideA = std::stoi(argv[11]);
     const int StrideB = std::stoi(argv[12]);
-    const int StrideC  = std::stoi(argv[13]);
-    const int StrideD0 = std::stoi(argv[14]);
-    const int StrideD1 = std::stoi(argv[15]);
+    const int StrideD0 = std::stoi(argv[13]);
+    const int StrideD1 = std::stoi(argv[14]);
+    const int StrideE  = std::stoi(argv[15]);

     using F16 = ck::half_t;
     using Row = ck::tensor_layout::gemm::RowMajor;
     using Col = ck::tensor_layout::gemm::ColumnMajor;

-    auto profile =
-        [&](auto a_type, auto b_type, auto c_type, auto a_layout, auto b_layout, auto c_layout) {
-            using ADataType = decltype(a_type);
-            using BDataType = decltype(b_type);
-            using CDataType = decltype(c_type);
-
-            using ALayout = decltype(a_layout);
-            using BLayout = decltype(b_layout);
-            using CLayout = decltype(c_layout);
-
-            const int DefaultStrideA = ck::is_same_v<ALayout, Row> ? K : M;
-            const int DefaultStrideB = ck::is_same_v<BLayout, Row> ? N : K;
-            const int DefaultStrideC = ck::is_same_v<CLayout, Row> ? N : M;
-
-            return ck::profiler::
-                profile_gemm_gelu_impl<ADataType, BDataType, CDataType, ALayout, BLayout, CLayout>(
-                    do_verification,
-                    init_method,
-                    do_log,
-                    time_kernel,
-                    M,
-                    N,
-                    K,
-                    (StrideA < 0) ? DefaultStrideA : StrideA,
-                    (StrideB < 0) ? DefaultStrideB : StrideB,
-                    (StrideC < 0) ? DefaultStrideC : StrideC);
-        };
+    auto profile = [&](auto a_type,
+                       auto b_type,
+                       auto d0_type,
+                       auto d1_type,
+                       auto e_type,
+                       auto a_layout,
+                       auto b_layout,
+                       auto d0_layout,
+                       auto d1_layout,
+                       auto e_layout) {
+        using ADataType  = decltype(a_type);
+        using BDataType  = decltype(b_type);
+        using D0DataType = decltype(d0_type);
+        using D1DataType = decltype(d1_type);
+        using EDataType  = decltype(e_type);
+
+        using ALayout  = decltype(a_layout);
+        using BLayout  = decltype(b_layout);
+        using D0Layout = decltype(d0_layout);
+        using D1Layout = decltype(d1_layout);
+        using ELayout  = decltype(e_layout);
+
+        const int DefaultStrideA  = ck::is_same_v<ALayout, Row> ? K : M;
+        const int DefaultStrideB  = ck::is_same_v<BLayout, Row> ? N : K;
+        const int DefaultStrideD0 = ck::is_same_v<D0Layout, Row> ? N : M;
+        const int DefaultStrideD1 = ck::is_same_v<D1Layout, Row> ? N : M;
+        const int DefaultStrideE  = ck::is_same_v<ELayout, Row> ? N : M;
+
+        return ck::profiler::profile_gemm_add_add_fastgelu_impl<ADataType,
+                                                                BDataType,
+                                                                D0DataType,
+                                                                D1DataType,
+                                                                EDataType,
+                                                                ALayout,
+                                                                BLayout,
+                                                                D0Layout,
+                                                                D1Layout,
+                                                                ELayout>(
+            do_verification,
+            init_method,
+            do_log,
+            time_kernel,
+            M,
+            N,
+            K,
+            (StrideA < 0) ? DefaultStrideA : StrideA,
+            (StrideB < 0) ? DefaultStrideB : StrideB,
+            (StrideD0 < 0) ? DefaultStrideD0 : StrideD0,
+            (StrideD1 < 0) ? DefaultStrideD1 : StrideD1,
+            (StrideE < 0) ? DefaultStrideE : StrideE);
+    };

-    if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_KN_MN)
+    if(data_type == MatrixDataType::F16_F16_F16_F16_F16 && layout == MatrixLayout::MK_KN_MN_MN_MN)
     {
-        return profile(F16{}, F16{}, F16{}, Row{}, Row{}, Row{});
+        return profile(F16{}, F16{}, F16{}, F16{}, F16{}, Row{}, Row{}, Row{}, Row{}, Row{});
     }
-    else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_NK_MN)
+    else if(data_type == MatrixDataType::F16_F16_F16_F16_F16 &&
+            layout == MatrixLayout::MK_NK_MN_MN_MN)
     {
-        return profile(F16{}, F16{}, F16{}, Row{}, Col{}, Row{});
+        return profile(F16{}, F16{}, F16{}, F16{}, F16{}, Row{}, Col{}, Row{}, Row{}, Row{});
     }
-    else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_KN_MN)
+    else if(data_type == MatrixDataType::F16_F16_F16_F16_F16 &&
+            layout == MatrixLayout::KM_KN_MN_MN_MN)
     {
-        return profile(F16{}, F16{}, F16{}, Col{}, Row{}, Row{});
+        return profile(F16{}, F16{}, F16{}, F16{}, F16{}, Col{}, Row{}, Row{}, Row{}, Row{});
     }
-    else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_NK_MN)
+    else if(data_type == MatrixDataType::F16_F16_F16_F16_F16 &&
+            layout == MatrixLayout::KM_NK_MN_MN_MN)
     {
-        return profile(F16{}, F16{}, F16{}, Col{}, Col{}, Row{});
+        return profile(F16{}, F16{}, F16{}, F16{}, F16{}, Col{}, Col{}, Row{}, Row{}, Row{});
     }
     else
     {
...
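Since `argc` is checked against 16 above, a full profiler invocation supplies the op name plus fifteen arguments. The run below is hypothetical (data type 1 = fp16, layout 1 = MK_NK_MN_MN_MN, and negative strides fall back to the layout defaults computed in the lambda):

```bash
# argv[2..3]: data type, layout; argv[4..7]: verify, init, log, time kernel;
# argv[8..15]: M N K StrideA StrideB StrideD0 StrideD1 StrideE (-1 picks the default)
./bin/ckProfiler gemm_add_add_fastgelu 1 1 1 1 0 1 3840 4096 4096 -1 -1 -1 -1 -1
```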
@@ -6,21 +6,21 @@
 #include "profile_convnd_fwd.hpp"

 int profile_gemm(int, char*[]);
-int profile_gemm_bias_2d(int, char*[]);
-int profile_gemm_bias_relu(int, char*[]);
-int profile_gemm_bias_relu_add(int, char*[]);
-int profile_gemm_reduce(int, char*[]);
-int profile_batched_gemm(int, char*[]);
-int profile_grouped_gemm(int, char*[]);
-int profile_conv_fwd(int, char*[]);
-int profile_conv_fwd_bias_relu(int, char*[]);
-int profile_conv_fwd_bias_relu_add(int, char*[]);
-int profile_conv_fwd_bias_relu_atomic_add(int, char*[]);
-int profile_convnd_bwd_data(int, char*[], int);
-int profile_reduce(int, char*[]);
-int profile_conv_bwd_weight(int, char*[]);
-int profile_batched_gemm_reduce(int, char*[]);
+// int profile_gemm_bias_2d(int, char*[]);
+// int profile_gemm_bias_relu(int, char*[]);
+// int profile_gemm_bias_relu_add(int, char*[]);
+// int profile_gemm_reduce(int, char*[]);
+// int profile_batched_gemm(int, char*[]);
+// int profile_grouped_gemm(int, char*[]);
+// int profile_conv_fwd(int, char*[]);
+// int profile_conv_fwd_bias_relu(int, char*[]);
+// int profile_conv_fwd_bias_relu_add(int, char*[]);
+// int profile_conv_fwd_bias_relu_atomic_add(int, char*[]);
+// int profile_convnd_bwd_data(int, char*[], int);
+// int profile_reduce(int, char*[]);
+// int profile_conv_bwd_weight(int, char*[]);
+// int profile_batched_gemm_reduce(int, char*[]);
 int profile_gemm_add_add_fastgelu(int, char*[]);

 static void print_helper_message()
@@ -58,6 +58,7 @@ int main(int argc, char* argv[])
     {
         return profile_gemm(argc, argv);
     }
+#if 0
     else if(strcmp(argv[1], "gemm_bias_2d") == 0)
     {
         return profile_gemm_bias_2d(argc, argv);
@@ -122,6 +123,7 @@ int main(int argc, char* argv[])
     {
         return profile_conv_bwd_weight(argc, argv);
     }
+#endif
     else if(strcmp(argv[1], "gemm_add_add_fastgelu") == 0)
     {
         return profile_gemm_add_add_fastgelu(argc, argv);
...