Commit 1e6272d3 authored by root's avatar root
Browse files

packed in thread_copy_v4

parent 0bd13d76
...@@ -157,7 +157,7 @@ bool run_splitK_gemm(const ProblemSize& problem_size, const ExecutionConfig& con ...@@ -157,7 +157,7 @@ bool run_splitK_gemm(const ProblemSize& problem_size, const ExecutionConfig& con
if(config.time_kernel) if(config.time_kernel)
{ {
float ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel, 0, 300, 500}); float ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel, 0, 300, 200});
std::size_t flop = std::size_t(2) * M * N * K; std::size_t flop = std::size_t(2) * M * N * K;
std::size_t num_btype = std::size_t num_btype =
......
...@@ -35,7 +35,7 @@ using AccDataType = F32; ...@@ -35,7 +35,7 @@ using AccDataType = F32;
using CDataType = F16; using CDataType = F16;
using ALayout = Row; using ALayout = Row;
using BLayout = Row; using BLayout = Col;
using CLayout = Row; using CLayout = Row;
using AElementOp = PassThrough; using AElementOp = PassThrough;
...@@ -55,7 +55,8 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdlSplitKCShu ...@@ -55,7 +55,8 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdlSplitKCShu
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | //######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
//< ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8>; //< ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8>;
//< ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmDefault, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16, ck::PipelineVersion::v2>; //< ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmDefault, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, true, 1, 1, S<1, 32, 1, 8>, 8, F16, ck::PipelineVersion::v2>;
< ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmDefault, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 4, F16, ck::PipelineVersion::v1, ck::LoopScheduler::Interwave>; //< ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmDefault, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 4, F16, ck::PipelineVersion::v1, ck::LoopScheduler::Interwave>;
< ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmDefault, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 16, 1, 4>, 8, F16, ck::PipelineVersion::v1, ck::LoopScheduler::Interwave>;
// clang-format on // clang-format on
#include "run_splitK_gemm_example.inc" #include "run_splitK_gemm_example.inc"
......
...@@ -21,6 +21,7 @@ ...@@ -21,6 +21,7 @@
template <ck::index_t... Is> template <ck::index_t... Is>
using S = ck::Sequence<Is...>; using S = ck::Sequence<Is...>;
using F8 = ck::f8_t;
using F16 = ck::half_t; using F16 = ck::half_t;
using F32 = float; using F32 = float;
...@@ -30,12 +31,12 @@ using Col = ck::tensor_layout::gemm::ColumnMajor; ...@@ -30,12 +31,12 @@ using Col = ck::tensor_layout::gemm::ColumnMajor;
using PassThrough = ck::tensor_operation::element_wise::PassThrough; using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using ADataType = F16; using ADataType = F16;
using BDataType = F16; using BDataType = F8;
using AccDataType = F32; using AccDataType = F32;
using CDataType = F16; using CDataType = F16;
using ALayout = Row; using ALayout = Row;
using BLayout = Row; using BLayout = Col;
using CLayout = Row; using CLayout = Row;
using AElementOp = PassThrough; using AElementOp = PassThrough;
...@@ -47,12 +48,14 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecializa ...@@ -47,12 +48,14 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecializa
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdlSplitKCShuffle using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdlSplitKCShuffle
// clang-format off // clang-format off
// GemmXdlSplitKCShuffle_MNKPadding_RRR_B128_Vec8x4x4_32x128x4x8 LoopScheduler: Interwave, PipelineVersion: v1 // GemmXdlSplitKCShuffle_MNKPadding_RRR_B128_Vec8x4x4_32x128x4x8 LoopScheduler: Interwave, PipelineVersion: v1
//######| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| KPer| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| //######| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| KPer| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| //######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector|
//######| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| //######| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | //######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
//< ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8>; //< ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8>;
< ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmDefault, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 4, F16, ck::PipelineVersion::v1, ck::LoopScheduler::Interwave>; //< ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmDefault, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 4, F16, ck::PipelineVersion::v1, ck::LoopScheduler::Interwave>;
//< ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmDefault, 128, 32, 256, 4, 8, 32, 32, 1, 4, S<1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 8, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16, ck::PipelineVersion::v1, ck::LoopScheduler::Interwave>;
< ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmDefault, 64, 32, 64, 4, 16, 32, 32, 1, 2, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 16, 16, true, S<1, 4, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 16, 16, true, 1, 1, S<1, 16, 1, 4>, 8, F16, ck::PipelineVersion::v1, ck::LoopScheduler::Interwave>;
// clang-format on // clang-format on
#include "run_splitK_gemm_example.inc" #include "run_splitK_gemm_example.inc"
......
...@@ -228,6 +228,7 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout, ...@@ -228,6 +228,7 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
{ {
if(kbatch == 1) if(kbatch == 1)
{ {
#if 1
const auto kernel = const auto kernel =
kernel_gemm_xdlops_v2r4r2_simplified<GridwiseGemm, kernel_gemm_xdlops_v2r4r2_simplified<GridwiseGemm,
true, true,
...@@ -238,6 +239,7 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout, ...@@ -238,6 +239,7 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
CElementwiseOperation>; CElementwiseOperation>;
Run(kernel); Run(kernel);
#endif
} }
else else
{ {
...@@ -253,6 +255,7 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout, ...@@ -253,6 +255,7 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
Run(kernel); Run(kernel);
} }
} }
#if 1
else else
{ {
if(kbatch == 1) if(kbatch == 1)
...@@ -282,6 +285,7 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout, ...@@ -282,6 +285,7 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
Run(kernel); Run(kernel);
} }
} }
#endif
return ave_time; return ave_time;
} }
......
...@@ -1156,7 +1156,6 @@ struct ThreadwiseTensorSliceTransfer_v4 ...@@ -1156,7 +1156,6 @@ struct ThreadwiseTensorSliceTransfer_v4
src_ref_to_origin_disp_idx + data_to_origin_disp_idx + src_ref_to_origin_disp_idx + data_to_origin_disp_idx +
i * src_scalar_step_in_vector); i * src_scalar_step_in_vector);
// apply type convert
src_tmp_vector.template AsType<SrcData>()(i) = src_buf[Number<src_offset>{}]; src_tmp_vector.template AsType<SrcData>()(i) = src_buf[Number<src_offset>{}];
}); });
} }
...@@ -1164,10 +1163,13 @@ struct ThreadwiseTensorSliceTransfer_v4 ...@@ -1164,10 +1163,13 @@ struct ThreadwiseTensorSliceTransfer_v4
// DstData) // DstData)
vector_type_maker_t<DstData, SrcScalarPerVector> dst_tmp_vector; vector_type_maker_t<DstData, SrcScalarPerVector> dst_tmp_vector;
using dst_v_t = typename vector_type_maker_t<DstData, 2>::type;
using src_v_t = typename vector_type_maker_t<SrcData, 2>::type;
// TODO: if SrcData and DstData are vetor type, then static_cast may not compile // TODO: if SrcData and DstData are vetor type, then static_cast may not compile
static_for<0, SrcScalarPerVector, 1>{}([&](auto i) { static_for<0, SrcScalarPerVector / 2, 1>{}([&](auto i) {
dst_tmp_vector.template AsType<DstData>()(i) = dst_tmp_vector.template AsType<dst_v_t>()(i) =
type_convert<DstData>(src_tmp_vector.template AsType<SrcData>()[i]); type_convert<dst_v_t>(src_tmp_vector.template AsType<src_v_t>()[i]);
}); });
// copy data from dst_tmp_vector into dst_buf // copy data from dst_tmp_vector into dst_buf
......
...@@ -33,6 +33,22 @@ __host__ __device__ constexpr Y type_convert(X x) ...@@ -33,6 +33,22 @@ __host__ __device__ constexpr Y type_convert(X x)
return static_cast<Y>(type_convert<NonConstY, NonConstX>(x)); return static_cast<Y>(type_convert<NonConstY, NonConstX>(x));
} }
template <>
inline __host__ __device__ constexpr half2_t type_convert<half2_t, f8x2_t>(f8x2_t x)
{
uint32_t t = bit_cast<uint16_t>(x);
return bit_cast<half2_t>(t);
}
template <>
inline __host__ __device__ constexpr f8x2_t type_convert<f8x2_t, half2_t>(half2_t x)
{
uint16_t t = bit_cast<uint32_t>(x);
return bit_cast<f8x2_t>(t);
}
// convert bfp16 to fp32 // convert bfp16 to fp32
template <> template <>
inline __host__ __device__ constexpr float type_convert<float, bhalf_t>(bhalf_t x) inline __host__ __device__ constexpr float type_convert<float, bhalf_t>(bhalf_t x)
...@@ -129,10 +145,11 @@ template <> ...@@ -129,10 +145,11 @@ template <>
inline __host__ __device__ float type_convert<float, f8_t>(f8_t x) inline __host__ __device__ float type_convert<float, f8_t>(f8_t x)
{ {
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) #if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
float fval; //float fval;
uint32_t i32val = static_cast<uint32_t>(x); //uint32_t i32val = static_cast<uint32_t>(x);
fval = __builtin_amdgcn_cvt_f32_fp8(i32val, 0); //fval = __builtin_amdgcn_cvt_f32_fp8(i32val, 0);
// asm volatile("v_cvt_f32_fp8 %0, %1 src0_sel:BYTE_0" : "=v"(fval) : "v"(i32val)); // asm volatile("v_cvt_f32_fp8 %0, %1 src0_sel:BYTE_0" : "=v"(fval) : "v"(i32val));
float fval = bit_cast<uint8_t>(x);
return fval; return fval;
#else #else
constexpr bool negative_zero_nan = true; constexpr bool negative_zero_nan = true;
...@@ -194,12 +211,14 @@ inline __host__ __device__ half_t type_convert<half_t, f8_t>(f8_t x) ...@@ -194,12 +211,14 @@ inline __host__ __device__ half_t type_convert<half_t, f8_t>(f8_t x)
{ {
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) #if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
// use native conversion to float and convert to fp16 // use native conversion to float and convert to fp16
return type_convert<half_t>(type_convert<float>(x)); // return type_convert<half_t>(type_convert<float>(x));
uint32_t tmp = bit_cast<uint8_t>(x);
return static_cast<half_t>(tmp);;
#else #else
// constexpr bool negative_zero_nan = true; // constexpr bool negative_zero_nan = true;
// return utils::cast_from_f8<f8_t, half_t, negative_zero_nan>(x); // return utils::cast_from_f8<f8_t, half_t, negative_zero_nan>(x);
uint16_t t = bit_cast<uint8_t>(x); uint16_t t = bit_cast<uint8_t>(x);
return bit_cast<half_t>(t); return static_cast<half_t>(t);
#endif #endif
} }
......
...@@ -201,7 +201,7 @@ bool profile_gemm_splitk_impl(int do_verification, ...@@ -201,7 +201,7 @@ bool profile_gemm_splitk_impl(int do_verification,
std::string op_name = op_ptr->GetTypeString(); std::string op_name = op_ptr->GetTypeString();
float ave_time = float ave_time =
invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel, 0, 50, 100});
std::size_t flop = std::size_t(2) * M * N * K; std::size_t flop = std::size_t(2) * M * N * K;
......
...@@ -8,7 +8,7 @@ MY_PROJECT_SOURCE=$1 ...@@ -8,7 +8,7 @@ MY_PROJECT_SOURCE=$1
cmake \ cmake \
-D CMAKE_PREFIX_PATH=/opt/rocm \ -D CMAKE_PREFIX_PATH=/opt/rocm \
-D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \ -D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \
-D CMAKE_CXX_FLAGS="-std=c++17 -O3 -ftemplate-backtrace-limit=0 -fPIE -Wno-gnu-line-marker -save-temps=$PWD" \ -D CMAKE_CXX_FLAGS="-v --save-temps -std=c++17 -O3 -ftemplate-backtrace-limit=0 -fPIE -Wno-gnu-line-marker -save-temps=$PWD" \
-D CMAKE_BUILD_TYPE=Release \ -D CMAKE_BUILD_TYPE=Release \
-D BUILD_DEV=OFF \ -D BUILD_DEV=OFF \
-D GPU_TARGETS="gfx942" \ -D GPU_TARGETS="gfx942" \
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment