Commit a36ceb6d authored by aska-0096

initial commit

parent 55a01eef
@@ -40,5 +40,6 @@ add_dependencies(example_gemm_xdl example_gemm_xdl_fp64)
 add_custom_target(example_gemm_wmma)
 add_example_executable(example_gemm_wmma_fp16 gemm_wmma_fp16.cpp)
+add_example_executable(example_gemm_wmma_int8 gemm_wmma_int8.cpp)
 add_dependencies(example_gemm_wmma example_gemm_wmma_fp16)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include "common.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp"
using ADataType = int8_t;
using BDataType = int8_t;
using CDataType = int8_t;
using AccDataType = int32_t;
using CShuffleDataType = int8_t;
using ALayout = Row;
using BLayout = Col;
using CLayout = Row;
using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CElementOp = PassThrough;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
// clang-format off
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmWmma_CShuffle
// ######| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer|MRepeat|NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
// ######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | WMMA| WMMA| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN|MWmmaPerWave|NWmmaPerWave| _MBlock_MWaveMPerWmma| ScalarPerVector|
// ######| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerWmma| _NWaveNPerWmma|
// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 256, 128, 128, 8, 16, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 1, 1, S<1, 32, 1, 8>, 8, 1>;
// clang-format on
using ReferenceGemmInstance = ck::tensor_operation::host::
ReferenceGemm<ADataType, BDataType, CDataType, AccDataType, AElementOp, BElementOp, CElementOp>;
#include "run_gemm_example.inc"
int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); }
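Note on the instance above: the tile parameters are internally consistent. Below is a standalone constexpr sketch of the arithmetic, assuming RDNA3 wave32 execution; the MWave/NWave names are illustrative, not CK identifiers.

// Back-of-the-envelope check of the chosen tile (assumes wave_size = 32).
constexpr int BlockSize = 256, MPerBlock = 128, NPerBlock = 128;
constexpr int K0PerBlock = 8, K1 = 16;
constexpr int MPerWmma = 16, NPerWmma = 16, MRepeat = 4, NRepeat = 2;
constexpr int WaveSize = 32;
constexpr int MWave = MPerBlock / (MRepeat * MPerWmma); // 128 / 64 = 2
constexpr int NWave = NPerBlock / (NRepeat * NPerWmma); // 128 / 32 = 4
static_assert(MWave * NWave * WaveSize == BlockSize, "waves must exactly fill the workgroup");
static_assert(K0PerBlock * K1 == 128, "K slab covered per LDS stage");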
@@ -17,6 +17,7 @@ if(USE_BITINT_EXTENSION_INT4)
 endif() # USE_BITINT_EXTENSION_INT4
 add_example_executable(example_grouped_conv_fwd_bias_relu_add_wmma_fp16 grouped_conv_fwd_bias_relu_add_wmma_fp16.cpp)
+add_example_executable(example_grouped_conv_fwd_bias_hardtanh_wmma_int8 grouped_conv_fwd_bias_hardtanh_wmma_int8.cpp)
 add_example_executable(example_grouped_conv_fwd_xdl_fp16 grouped_conv_fwd_xdl_fp16.cpp)
...
@@ -137,7 +137,7 @@ inline bool parse_cmd_args(int argc,
         const ck::index_t num_dim_spatial = std::stoi(argv[4]);
         conv_param = ck::utils::conv::parse_conv_param(
-            num_dim_spatial, threshold_to_catch_partial_args, argv);
+            num_dim_spatial, threshold_to_catch_partial_args + 1, argv);
     }
     else
     {
...
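The `+1` above widens the argument-count threshold handed to `parse_conv_param` by one slot. A hedged sketch of the general pattern; the helper below is illustrative only, not the CK implementation.

// Illustrative only: a parser of this shape compares argc against a threshold to
// decide whether a full conv problem was specified or defaults should be used.
// Bumping the threshold by one would account for one extra positional argument.
inline bool has_full_conv_args(int argc, int threshold_to_catch_partial_args)
{
    return argc >= threshold_to_catch_partial_args + 1;
}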
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include "common_wmma.hpp"
// kernel data types
using InKernelDataType = I8;
using WeiKernelDataType = I8;
using AccDataType = I32;
using CShuffleDataType = I8;
using BiasKernelDataType = I8;
using ResidualKernelDataType = I8;
using OutKernelDataType = I8;
// tensor data types
using InUserDataType = InKernelDataType;
using WeiUserDataType = WeiKernelDataType;
using OutUserDataType = OutKernelDataType;
using InElementOp = PassThrough;
using WeiElementOp = PassThrough;
using OutElementOp = ck::tensor_operation::element_wise::AddHardTanhAdd;
#include "run_grouped_conv_fwd_bias_relu_add_wmma_example.inc"
int main(int argc, char* argv[]) { return !run_grouped_conv_fwd_bias_relu_add_example(argc, argv); }
@@ -51,33 +51,33 @@ using DeviceConvFwdInstance =
         OutElementOp,
         ConvSpec,       // ConvForwardSpecialization
         GemmSpec,       // GemmSpecialization
-        256,            // BlockSize
+        128,            // BlockSize
         128,            // MPerBlock
-        128,            // NPerBlock
-        4,              // K0PerBlock
-        8,              // K1
+        16,             // NPerBlock
+        2,              // K0PerBlock
+        16,             // K1
         16,             // MPerWMMA
         16,             // NPerWMMA
-        4,              // MRepeat
-        2,              // NRepeat
-        S<4, 64, 1>,    // ABlockTransferThreadClusterLengths_AK0_M_AK1
+        2,              // MRepeat
+        1,              // NRepeat
+        S<1, 128, 1>,   // ABlockTransferThreadClusterLengths_AK0_M_AK1
         S<1, 0, 2>,     // ABlockTransferThreadClusterArrangeOrder
         S<1, 0, 2>,     // ABlockTransferSrcAccessOrder
         2,              // ABlockTransferSrcVectorDim
-        8,              // ABlockTransferSrcScalarPerVector
-        8,              // ABlockTransferDstScalarPerVector_AK1
+        16,             // ABlockTransferSrcScalarPerVector
+        16,             // ABlockTransferDstScalarPerVector_AK1
         true,           // ABlockLdsExtraM
-        S<4, 64, 1>,    // BBlockTransferThreadClusterLengths_BK0_N_BK1
+        S<1, 16, 8>,    // BBlockTransferThreadClusterLengths_BK0_N_BK1
         S<1, 0, 2>,     // BBlockTransferThreadClusterArrangeOrder
         S<1, 0, 2>,     // BBlockTransferSrcAccessOrder
         2,              // BBlockTransferSrcVectorDim
-        8,              // BBlockTransferSrcScalarPerVector
-        8,              // BBlockTransferDstScalarPerVector_BK1
+        2,              // BBlockTransferSrcScalarPerVector
+        2,              // BBlockTransferDstScalarPerVector_BK1
         true,           // BBlockLdsExtraN
         1,
         1,
-        S<1, 32, 1, 8>,
-        8>;
+        S<1, 32, 1, 4>,
+        4>;
 template <ck::index_t NDimSpatial>
 using HostConvFwdInstance = ck::tensor_operation::host::ReferenceConvFwd<NDimSpatial,
...
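The retune trades the 128x128 tile for a narrow 128x16 tile with a single wave in N, presumably to fit a smaller GEMM-N in this conv example. The new parameter set is internally consistent under the same wave32 assumption as the sketch above (names again illustrative, not CK identifiers).

// Consistency check for the retuned parameter set (assumes wave_size = 32).
constexpr int BlockSize = 128, MPerBlock = 128, NPerBlock = 16;
constexpr int MPerWmma = 16, NPerWmma = 16, MRepeat = 2, NRepeat = 1;
constexpr int WaveSize = 32;
constexpr int MWave = MPerBlock / (MRepeat * MPerWmma); // 4
constexpr int NWave = NPerBlock / (NRepeat * NPerWmma); // 1
static_assert(MWave * NWave * WaveSize == BlockSize, "retuned tile still fills the workgroup");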
@@ -143,6 +143,26 @@ struct AddHardswishAdd
     }
 };
 
+struct AddHardTanhAdd
+{
+    template <typename Y, typename X0, typename X1, typename X2>
+    __host__ __device__ constexpr void operator()(Y&, const X0&, const X1&, const X2&) const;
+
+    template <>
+    __host__ __device__ constexpr void operator()<int8_t, int8_t, int8_t, int8_t>(
+        int8_t& y, const int8_t& x0, const int8_t& x1, const int8_t& x2) const
+    {
+        int32_t a = x0 + x1;
+        int32_t b = a;
+        if(b > 1)
+            b = 1;
+        else if(b < -1)
+            b = -1;
+        int32_t c = b + x2;
+        y = c;
+    }
+};
+
 // C = A * B
 // E = C + D0 + D1
 struct AddAdd
...
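A host-side reference of what AddHardTanhAdd computes, useful for eyeballing the integer semantics: the hardtanh clamps the int8 sum to [-1, 1] before the residual add, and the final store narrows int32 to int8 exactly as `y = c;` does above. Sketch only, not a CK API.

#include <algorithm>
#include <cassert>
#include <cstdint>

// Reference semantics: y = hardtanh(x0 + x1) + x2, hardtanh bounds [-1, 1].
inline int8_t add_hardtanh_add_ref(int8_t x0, int8_t x1, int8_t x2)
{
    const int32_t sum     = int32_t{x0} + int32_t{x1};
    const int32_t clamped = std::min(std::max(sum, -1), 1);
    return static_cast<int8_t>(clamped + x2); // narrowing matches `y = c;`
}

int main()
{
    assert(add_hardtanh_add_ref(100, 27, 5) == 6);   // 127 clamps to 1, then +5
    assert(add_hardtanh_add_ref(-100, -50, 3) == 2); // -150 clamps to -1, then +3
    assert(add_hardtanh_add_ref(0, 0, -4) == -4);    // sum inside [-1, 1] passes through
    return 0;
}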
@@ -673,7 +673,7 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle
         constexpr auto KPack = math::integer_least_multiple(K1, WmmaK);
 
         auto blockwise_gemm =
-            BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_FIFO<BlockSize,
+            BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle<BlockSize,
                 ADataType,
                 BDataType,
                 AccDataType,
...
@@ -414,7 +414,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
         constexpr auto KPack = math::integer_least_multiple(K1, WmmaK);
 
         auto blockwise_gemm =
-            BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_FIFO<BlockSize,
+            BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle<BlockSize,
                 FloatA,
                 FloatB,
                 FloatAcc,
...
@@ -262,12 +262,12 @@ struct wmma_type<WmmaInstr::wmma_i32_16x16x16_iu8,
     template <index_t MPerWmma,
               index_t NPerWmma,
-              bool neg_a,
-              bool neg_b,
-              bool clamp,
               class FloatA,
               class FloatB,
-              class FloatC>
+              class FloatC,
+              bool neg_a = true,
+              bool neg_b = true,
+              bool clamp = false>
     __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
     {
         if constexpr(wave_size == 32)
...
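Moving the three bool parameters behind the type parameters and giving them defaults changes the call side: the types are deduced from the arguments, so the common case needs no explicit template arguments at all. A standalone illustration (the function name is mine, not CK's):

// Before: run<true, true, false>(a, b, c)  -- bools had to be spelled out first.
// After:  run(a, b, c)                     -- types deduced, bools defaulted.
template <class FloatA, class FloatB, class FloatC,
          bool neg_a = true, bool neg_b = true, bool clamp = false>
__device__ void run_sketch(const FloatA& a, const FloatB& b, FloatC& reg_c)
{
    // ... issue the WMMA with neg_a/neg_b/clamp folded in at compile time ...
}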
@@ -361,5 +361,14 @@ __device__ void amd_assembly_wmma_f32_16x16x16_f16_w32(half16_t a, half16_t b, f
     asm volatile("v_wmma_f32_16x16x16_f16 %0, %1, %2, %0" : "=v"(c) : "v"(a), "v"(b), "0"(c));
 }
 
+// __device__ void amd_assembly_wmma_f32_16x16x16_iu8_w32(bool neg_a, int8x16_t a,
+//                                                        bool neg_b, int8x16_t b,
+//                                                        int32x8_t& c, bool clamp)
+// {
+//     asm volatile("v_wmma_f32_16x16x16_iu8 %0, %1, %2, %0 neg_lo:[%3, %4, %5]"
+//                  : "=v"(c)
+//                  : "v"(a), "v"(b), "0"(c), ""(neg_a), ""(neg_b), ""(clamp));
+// }
+
 } // namespace ck
 #endif
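The helper stays commented out for good reason: `""(neg_a)` is not a valid inline-asm operand constraint, and a modifier like `neg_lo:[...]` needs compile-time immediates, so runtime bools cannot be substituted into it (note also that the iu8 variant accumulates into i32, so the `f32` in the name and mnemonic looks copied from the f16 helper). A compile-time sketch of the alternative, assuming the builtin argument order implied by the call site in the next hunk; not verified against compiler docs.

// Sketch only: fold the flags in at compile time and use the compiler builtin
// that the surviving intrinsic path appears to wrap (signature assumed).
template <bool neg_a, bool neg_b, bool clamp>
__device__ void wmma_i32_16x16x16_iu8_w32_sketch(int32x4_t a, int32x4_t b, int32x8_t& c)
{
    c = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32(neg_a, a, neg_b, b, c, clamp);
}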
@@ -102,6 +102,9 @@ struct intrin_wmma_i32_16x16x16_iu8_w32<16, 16, neg_a, neg_b, clamp>
             bit_cast<int32x4_t>(reg_b),
             reg_c.template AsType<int32x8_t>()[Number<0>{}],
             clamp);
+        // amd_assembly_wmma_f32_16x16x16_iu8_w32(
+        //     neg_a, reg_a, neg_b, reg_b, reg_c.template AsType<int32x8_t>()(Number<0>{}), clamp);
     }
 };
...