Commit b7419eec authored by Jing Zhang's avatar Jing Zhang
Browse files

separate float a/b

parent 57d0ea67
......@@ -6,12 +6,12 @@
#include "ck/tensor_operation/gpu/device/impl/device_gemm_dl.hpp"
using ADataType = ck::half_t;
using BDataType = ck::half_t;
using BDataType = int8_t;
using CDataType = ck::half_t;
using AccDataType = float;
using ALayout = Row;
using BLayout = Col;
using BLayout = Row;
using CLayout = Row;
using AElementOp = PassThrough;
......@@ -23,12 +23,11 @@ static constexpr auto GemmMNPadding = ck::tensor_operation::device::GemmSpeciali
// clang-format off
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmDl
// ######| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| CThreadTransfer| BThreadTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer|
// ######| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BThreadTransfer| BThreadTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer|
// ######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcAccess| SrcVectorDim| SrcScalarPerVector| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector|
// ######| | | | | | | | Operation| Operation| Operation| | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| Order| | | Order| | |
// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
//< ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmMNPadding, 64, 1, 64, 32, 2, 1, 1, 1, S<1, 1, 1, 2>, S<32, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<1, 2, 0, 3>, 3, 2, S<0, 1, 2, 3, 4, 5>, 5, 1>;
< ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmMNPadding, 64, 1, 64, 16, 4, 1, 1, 1, S<1, 1, 1, 4>, S<16, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<1, 2, 0, 3>, 3, 4, S<0, 1, 2, 3, 4, 5>, 5, 1>;
< ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmMNPadding, 64, 1, 512, 2, 4, 1, 8, 1, S<1, 1, 1, 4>, S<2, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<1, 2, 0, 3>, S<1, 1, 1, 4>, S<1, 2, 0, 3>, 2, 8, S<0, 1, 2, 3, 4, 5>, 5, 1>;
// clang-format on
using ReferenceGemmInstance = ck::tensor_operation::host::
......
......@@ -80,6 +80,7 @@ struct DeviceGemmDl : public DeviceGemm<ALayout,
using GridwiseGemm =
GridwiseGemmDl_km_kn_mn_v1r3<BlockSize,
ADataType,
BDataType,
AccDataType,
CDataType,
InMemoryDataOperationEnum::Set,
......@@ -184,10 +185,17 @@ struct DeviceGemmDl : public DeviceGemm<ALayout,
float ave_time = 0;
using ComputeType = ADataType;
if(has_main_k_block_loop && has_double_tail_k_block_loop)
{
const auto kernel =
kernel_gemm_dl_v1r3<GridwiseGemm, ADataType, CDataType, true, true>;
const auto kernel = kernel_gemm_dl_v1r3<GridwiseGemm,
ADataType,
BDataType,
ComputeType,
CDataType,
true,
true>;
ave_time = launch_and_time_kernel(stream_config,
kernel,
......@@ -206,8 +214,13 @@ struct DeviceGemmDl : public DeviceGemm<ALayout,
}
else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
{
const auto kernel =
kernel_gemm_dl_v1r3<GridwiseGemm, ADataType, CDataType, true, false>;
const auto kernel = kernel_gemm_dl_v1r3<GridwiseGemm,
ADataType,
BDataType,
ComputeType,
CDataType,
true,
false>;
ave_time = launch_and_time_kernel(stream_config,
kernel,
......@@ -226,8 +239,13 @@ struct DeviceGemmDl : public DeviceGemm<ALayout,
}
else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
{
const auto kernel =
kernel_gemm_dl_v1r3<GridwiseGemm, ADataType, CDataType, false, true>;
const auto kernel = kernel_gemm_dl_v1r3<GridwiseGemm,
ADataType,
BDataType,
ComputeType,
CDataType,
false,
true>;
ave_time = launch_and_time_kernel(stream_config,
kernel,
......@@ -246,8 +264,13 @@ struct DeviceGemmDl : public DeviceGemm<ALayout,
}
else
{
const auto kernel =
kernel_gemm_dl_v1r3<GridwiseGemm, ADataType, CDataType, false, false>;
const auto kernel = kernel_gemm_dl_v1r3<GridwiseGemm,
ADataType,
BDataType,
ComputeType,
CDataType,
false,
false>;
ave_time = launch_and_time_kernel(stream_config,
kernel,
......
......@@ -19,7 +19,9 @@
namespace ck {
template <typename GridwiseGemm,
typename FloatAB,
typename FloatA,
typename FloatB,
typename ComputeType,
typename FloatC,
bool HasMainKBlockLoop,
bool HasDoubleTailKBlockLoop>
......@@ -27,8 +29,8 @@ __global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_gemm_dl_v1r3(const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
kernel_gemm_dl_v1r3(const FloatA* __restrict__ p_a_grid,
const FloatB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
const index_t M,
const index_t N,
......@@ -38,9 +40,9 @@ __global__ void
const index_t StrideC)
{
constexpr index_t shared_block_size =
GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(ComputeType);
__shared__ FloatAB p_shared_block[shared_block_size];
__shared__ ComputeType p_shared_block[shared_block_size];
const auto a_grid_desc_k0_m_k1 = GridwiseGemm::MakeAGridDescriptor_K0_M_K1(M, K, StrideA);
const auto b_grid_desc_k0_n_k1 = GridwiseGemm::MakeBGridDescriptor_K0_N_K1(K, N, StrideB);
......@@ -68,7 +70,8 @@ __global__ void
}
template <index_t BlockSize,
typename FloatAB,
typename FloatA,
typename FloatB,
typename FloatAcc,
typename FloatC,
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
......@@ -121,7 +124,7 @@ struct GridwiseGemmDl_km_kn_mn_v1r3
constexpr auto a_block_aligned_space_size =
math::integer_least_multiple(a_block_desc_k_m.GetElementSpaceSize(), max_lds_align);
return 2 * (a_block_aligned_space_size) * sizeof(FloatAB);
return 2 * (a_block_aligned_space_size) * sizeof(ComputeType);
}
__host__ __device__ static constexpr index_t CalculateGridSize(index_t M, index_t N)
......@@ -368,12 +371,14 @@ struct GridwiseGemmDl_km_kn_mn_v1r3
decltype(MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(CGridDesc_M_N{}));
using Block2CTileMap = decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}));
using ComputeType = FloatA;
template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
__device__ static void
Run(const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
Run(const FloatA* __restrict__ p_a_grid,
const FloatB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
FloatAB* __restrict__ p_shared_block,
ComputeType* __restrict__ p_shared_block,
const AGridDesc_K0_M0_M1_K1& a_grid_desc_k0_m0_m1_k1,
const BGridDesc_K0_N0_N1_K1& b_grid_desc_k0_n0_n1_k1,
const CGridDesc_M0_M10_M11_N0_N10_N11& c_grid_desc_m0_m10_m11_n0_n10_n11,
......@@ -423,6 +428,8 @@ struct GridwiseGemmDl_km_kn_mn_v1r3
a_k0_m_k1_block_desc.GetElementSpaceSize() &&
"wrong!");
ignore = a_global_buf;
// A matrix blockwise copy
auto a_blockwise_copy = BlockwiseTensorSliceTransfer_v5r1<
BlockSize,
......@@ -431,8 +438,8 @@ struct GridwiseGemmDl_km_kn_mn_v1r3
ABlockTransferThreadSliceLengths_K0_M0_M1_K1,
ABlockTransferThreadClusterLengths_K0_M0_M1_K1,
ABlockTransferThreadClusterArrangeOrder,
FloatAB,
FloatAB,
FloatA,
ComputeType,
remove_reference_t<decltype(a_grid_desc_k0_m0_m1_k1)>,
decltype(a_block_desc_k0_m0_m1_k1),
ABlockTransferSrcAccessOrder,
......@@ -451,8 +458,8 @@ struct GridwiseGemmDl_km_kn_mn_v1r3
make_tuple(Number<K0PerBlock>{}, I1, Number<NPerThread>{}, Number<K1>{}));
auto b_threadwise_copy =
ThreadwiseTensorSliceTransfer_v2<FloatAB,
FloatAB,
ThreadwiseTensorSliceTransfer_v2<FloatB,
ComputeType,
remove_reference_t<decltype(b_grid_desc_k0_n0_n1_k1)>,
decltype(b_thread_desc_k0_n0_n1_k1),
Sequence<K0PerBlock, 1, NPerThread, K1.value>,
......@@ -470,8 +477,8 @@ struct GridwiseGemmDl_km_kn_mn_v1r3
const auto blockwise_gemm =
BlockwiseGemmDlops_km_kn_m0m1n0n1_v3<BlockSize,
FloatAB,
FloatAB,
ComputeType,
ComputeType,
FloatAcc,
decltype(a_k0_m_k1_block_desc),
decltype(b_k0_n_k1_thread_desc),
......@@ -489,12 +496,12 @@ struct GridwiseGemmDl_km_kn_mn_v1r3
constexpr auto a_block_aligned_space_size = math::integer_least_multiple(
a_block_desc_k0_m0_m1_k1.GetElementSpaceSize(), max_lds_align);
FloatAB* p_a_block_double = p_shared_block;
ComputeType* p_a_block_double = p_shared_block;
auto b_thread_odd_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatAB>(
auto b_thread_odd_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeType>(
b_k0_n_k1_thread_desc.GetElementSpaceSize());
auto b_thread_even_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatAB>(
auto b_thread_even_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeType>(
b_k0_n_k1_thread_desc.GetElementSpaceSize());
// register allocation for output
......@@ -516,8 +523,8 @@ struct GridwiseGemmDl_km_kn_mn_v1r3
// LDS double buffer: preload data into LDS
{
a_blockwise_copy.RunRead(a_grid_desc_k0_m0_m1_k1, a_global_buf);
a_blockwise_copy.RunWrite(a_block_desc_k0_m0_m1_k1, a_block_even_buf);
// a_blockwise_copy.RunRead(a_grid_desc_k0_m0_m1_k1, a_global_buf);
// a_blockwise_copy.RunWrite(a_block_desc_k0_m0_m1_k1, a_block_even_buf);
b_threadwise_copy.Run(b_grid_desc_k0_n0_n1_k1,
b_global_buf,
......@@ -544,7 +551,7 @@ struct GridwiseGemmDl_km_kn_mn_v1r3
b_thread_slice_copy_step);
// LDS double buffer: load next data from device mem
a_blockwise_copy.RunRead(a_grid_desc_k0_m0_m1_k1, a_global_buf);
// a_blockwise_copy.RunRead(a_grid_desc_k0_m0_m1_k1, a_global_buf);
b_threadwise_copy.Run(b_grid_desc_k0_n0_n1_k1,
b_global_buf,
......@@ -558,7 +565,7 @@ struct GridwiseGemmDl_km_kn_mn_v1r3
blockwise_gemm.Run(a_block_even_buf, b_thread_even_buf, c_thread_buf);
// LDS double buffer: store next data to LDS
a_blockwise_copy.RunWrite(a_block_desc_k0_m0_m1_k1, a_block_odd_buf);
// a_blockwise_copy.RunWrite(a_block_desc_k0_m0_m1_k1, a_block_odd_buf);
// odd iteration
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_k0_m0_m1_k1,
......@@ -568,7 +575,7 @@ struct GridwiseGemmDl_km_kn_mn_v1r3
b_thread_slice_copy_step);
// LDS double buffer: load next data from device mem
a_blockwise_copy.RunRead(a_grid_desc_k0_m0_m1_k1, a_global_buf);
// a_blockwise_copy.RunRead(a_grid_desc_k0_m0_m1_k1, a_global_buf);
b_threadwise_copy.Run(b_grid_desc_k0_n0_n1_k1,
b_global_buf,
......@@ -582,7 +589,7 @@ struct GridwiseGemmDl_km_kn_mn_v1r3
blockwise_gemm.Run(a_block_odd_buf, b_thread_odd_buf, c_thread_buf);
// LDS double buffer: store next data to LDS
a_blockwise_copy.RunWrite(a_block_desc_k0_m0_m1_k1, a_block_even_buf);
// a_blockwise_copy.RunWrite(a_block_desc_k0_m0_m1_k1, a_block_even_buf);
k_block_data_begin += 2 * K0PerBlock;
} while(k_block_data_begin < K0 - 2 * K0PerBlock);
......@@ -598,7 +605,7 @@ struct GridwiseGemmDl_km_kn_mn_v1r3
block_sync_lds();
// LDS double buffer: load last data from device mem
a_blockwise_copy.RunRead(a_grid_desc_k0_m0_m1_k1, a_global_buf);
// a_blockwise_copy.RunRead(a_grid_desc_k0_m0_m1_k1, a_global_buf);
b_threadwise_copy.Run(b_grid_desc_k0_n0_n1_k1,
b_global_buf,
......@@ -610,7 +617,7 @@ struct GridwiseGemmDl_km_kn_mn_v1r3
blockwise_gemm.Run(a_block_even_buf, b_thread_even_buf, c_thread_buf);
// LDS double buffer: store last data to LDS
a_blockwise_copy.RunWrite(a_block_desc_k0_m0_m1_k1, a_block_odd_buf);
// a_blockwise_copy.RunWrite(a_block_desc_k0_m0_m1_k1, a_block_odd_buf);
block_sync_lds();
......
......@@ -12,7 +12,8 @@ cmake
-save-temps=$PWD" \
-D CMAKE_BUILD_TYPE=Release \
-D BUILD_DEV=ON \
-D GPU_TARGETS="gfx1100" \
-D GPU_TARGETS="gfx90a" \
-D DL_KERNEL=ON \
-D CMAKE_VERBOSE_MAKEFILE:BOOL=ON \
-D USE_BITINT_EXTENSION_INT4=OFF \
${MY_PROJECT_SOURCE}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment