Commit a36ceb6d authored by aska-0096's avatar aska-0096
Browse files

intial commit

parent 55a01eef
......@@ -40,5 +40,6 @@ add_dependencies(example_gemm_xdl example_gemm_xdl_fp64)
add_custom_target(example_gemm_wmma)
add_example_executable(example_gemm_wmma_fp16 gemm_wmma_fp16.cpp)
add_example_executable(example_gemm_wmma_int8 gemm_wmma_int8.cpp)
add_dependencies(example_gemm_wmma example_gemm_wmma_fp16)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include "common.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp"
using ADataType = int8_t;
using BDataType = int8_t;
using CDataType = int8_t;
using AccDataType = int32_t;
using CShuffleDataType = int8_t;
using ALayout = Row;
using BLayout = Col;
using CLayout = Row;
using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CElementOp = PassThrough;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
// clang-format off
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmWmma_CShuffle
// ######| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer|MRepeat|NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
// ######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | WMMA| WMMA| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN|MWmmaPerWave|NWmmaPerWave| _MBlock_MWaveMPerWmma| ScalarPerVector|
// ######| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerWmma| _NWaveNPerWmma|
// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 256, 128, 128, 8, 16, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 1, 1, S<1, 32, 1, 8>, 8, 1>;
// clang-format on
using ReferenceGemmInstance = ck::tensor_operation::host::
ReferenceGemm<ADataType, BDataType, CDataType, AccDataType, AElementOp, BElementOp, CElementOp>;
#include "run_gemm_example.inc"
int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); }
......@@ -17,6 +17,7 @@ if(USE_BITINT_EXTENSION_INT4)
endif() # USE_BITINT_EXTENSION_INT4
add_example_executable(example_grouped_conv_fwd_bias_relu_add_wmma_fp16 grouped_conv_fwd_bias_relu_add_wmma_fp16.cpp)
add_example_executable(example_grouped_conv_fwd_bias_hardtanh_wmma_int8 grouped_conv_fwd_bias_hardtanh_wmma_int8.cpp)
add_example_executable(example_grouped_conv_fwd_xdl_fp16 grouped_conv_fwd_xdl_fp16.cpp)
......
......@@ -137,7 +137,7 @@ inline bool parse_cmd_args(int argc,
const ck::index_t num_dim_spatial = std::stoi(argv[4]);
conv_param = ck::utils::conv::parse_conv_param(
num_dim_spatial, threshold_to_catch_partial_args, argv);
num_dim_spatial, threshold_to_catch_partial_args+1, argv);
}
else
{
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include "common_wmma.hpp"
// kernel data types
using InKernelDataType = I8;
using WeiKernelDataType = I8;
using AccDataType = I32;
using CShuffleDataType = I8;
using BiasKernelDataType = I8;
using ResidualKernelDataType = I8;
using OutKernelDataType = I8;
// tensor data types
using InUserDataType = InKernelDataType;
using WeiUserDataType = WeiKernelDataType;
using OutUserDataType = OutKernelDataType;
using InElementOp = PassThrough;
using WeiElementOp = PassThrough;
using OutElementOp = ck::tensor_operation::element_wise::AddHardTanhAdd;
#include "run_grouped_conv_fwd_bias_relu_add_wmma_example.inc"
int main(int argc, char* argv[]) { return !run_grouped_conv_fwd_bias_relu_add_example(argc, argv); }
......@@ -51,33 +51,33 @@ using DeviceConvFwdInstance =
OutElementOp,
ConvSpec, // ConvForwardSpecialization
GemmSpec, // GemmSpecialization
256, // BlockSize
128, // BlockSize
128, // MPerBlock
128, // NPerBlock
4, // K0PerBlock
8, // K1
16, // NPerBlock
2, // K0PerBlock
16, // K1
16, // MPerWMMA
16, // NPerWMMA
4, // MRepeat
2, // NRepeat
S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
2, // MRepeat
1, // NRepeat
S<1, 128, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // ABlockTransferSrcAccessOrder
2, // ABlockTransferSrcVectorDim
8, // ABlockTransferSrcScalarPerVector
8, // ABlockTransferDstScalarPerVector_AK1
16, // ABlockTransferSrcScalarPerVector
16, // ABlockTransferDstScalarPerVector_AK1
true, // ABlockLdsExtraM
S<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1
S<1, 16, 8>, // BBlockTransferThreadClusterLengths_BK0_N_BK1
S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // BBlockTransferSrcAccessOrder
2, // BBlockTransferSrcVectorDim
8, // BBlockTransferSrcScalarPerVector
8, // BBlockTransferDstScalarPerVector_BK1
2, // BBlockTransferSrcScalarPerVector
2, // BBlockTransferDstScalarPerVector_BK1
true, // BBlockLdsExtraN
1,
1,
S<1, 32, 1, 8>,
8>;
S<1, 32, 1, 4>,
4>;
template <ck::index_t NDimSpatial>
using HostConvFwdInstance = ck::tensor_operation::host::ReferenceConvFwd<NDimSpatial,
......
......@@ -143,6 +143,26 @@ struct AddHardswishAdd
}
};
struct AddHardTanhAdd
{
template <typename Y, typename X0, typename X1, typename X2>
__host__ __device__ constexpr void operator()(Y&, const X0&, const X1&, const X2&) const;
template <>
__host__ __device__ constexpr void operator()<int8_t, int8_t, int8_t, int8_t>(int8_t& y,
const int8_t& x0,
const int8_t& x1,
const int8_t& x2) const
{
int32_t a = x0 + x1;
int32_t b = a;
if(b>1) b = 1;
else if(b<-1) b = -1;
int32_t c = b + x2;
y = c;
}
};
// C = A * B
// E = C + D0 + D1
struct AddAdd
......
......@@ -673,7 +673,7 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle
constexpr auto KPack = math::integer_least_multiple(K1, WmmaK);
auto blockwise_gemm =
BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_FIFO<BlockSize,
BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle<BlockSize,
ADataType,
BDataType,
AccDataType,
......
......@@ -414,7 +414,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
constexpr auto KPack = math::integer_least_multiple(K1, WmmaK);
auto blockwise_gemm =
BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_FIFO<BlockSize,
BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle<BlockSize,
FloatA,
FloatB,
FloatAcc,
......
......@@ -262,12 +262,12 @@ struct wmma_type<WmmaInstr::wmma_i32_16x16x16_iu8,
template <index_t MPerWmma,
index_t NPerWmma,
bool neg_a,
bool neg_b,
bool clamp,
class FloatA,
class FloatB,
class FloatC>
class FloatC,
bool neg_a = true,
bool neg_b = true,
bool clamp = false>
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
{
if constexpr(wave_size == 32)
......
......@@ -361,5 +361,14 @@ __device__ void amd_assembly_wmma_f32_16x16x16_f16_w32(half16_t a, half16_t b, f
asm volatile("v_wmma_f32_16x16x16_f16 %0, %1, %2, %0" : "=v"(c) : "v"(a), "v"(b), "0"(c));
}
// __device__ void amd_assembly_wmma_f32_16x16x16_iu8_w32(bool neg_a, int8x16_t a,
// bool neg_b, int8x16_t b,
// int32x8_t& c, bool clamp)
// {
// asm volatile("v_wmma_f32_16x16x16_iu8 %0, %1, %2, %0 neg_lo:[%3, %4, %5]"
// : "=v"(c)
// : "v"(a), "v"(b), "0"(c), ""(neg_a), ""(neg_b), ""(clamp));
// }
} // namespace ck
#endif
......@@ -102,6 +102,9 @@ struct intrin_wmma_i32_16x16x16_iu8_w32<16, 16, neg_a, neg_b, clamp>
bit_cast<int32x4_t>(reg_b),
reg_c.template AsType<int32x8_t>()[Number<0>{}],
clamp);
// amd_assembly_wmma_f32_16x16x16_iu8_w32(
// neg_a, reg_a, neg_b, reg_b, reg_c.template AsType<int32x8_t>()(Number<0>{}), clamp);
}
};
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment