Commit e61489e2 authored by Jing Zhang's avatar Jing Zhang
Browse files

fixed

parent 0fd1d636
...@@ -30,20 +30,20 @@ using Col = ck::tensor_layout::gemm::ColumnMajor; ...@@ -30,20 +30,20 @@ using Col = ck::tensor_layout::gemm::ColumnMajor;
using PassThrough = ck::tensor_operation::element_wise::PassThrough; using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using ADataType = F16; using ADataType = F8;
using BDataType = F8; using BDataType = F8;
using AccDataType = F32; using AccDataType = F32;
using CDataType = F16; using CDataType = F16;
using ALayout = Row; using ALayout = Row;
using BLayout = Col; using BLayout = Row;
using CLayout = Row; using CLayout = Row;
using AElementOp = PassThrough; using AElementOp = PassThrough;
using BElementOp = PassThrough; using BElementOp = PassThrough;
using CElementOp = PassThrough; using CElementOp = PassThrough;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdlSplitKCShuffle using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdlSplitKCShuffle
// clang-format off // clang-format off
...@@ -51,7 +51,9 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdlSplitKCShu ...@@ -51,7 +51,9 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdlSplitKCShu
//######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| //######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector|
//######| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| //######| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | //######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8>; //< ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8>;
< ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmDefault, 256, 16, 256, 4, 8, 16, 16, 1, 4, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, true, 1, 1, S<1, 16, 1, 8>, 8, F16, ck::PipelineVersion::v1, ck::LoopScheduler::Interwave>;
//< ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmDefault, 256, 16, 512, 4, 8, 16, 16, 1, 8, S<1, 4, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 8, 8, true, 1, 2, S<1, 16, 1, 8>, 8, F16, ck::PipelineVersion::v1, ck::LoopScheduler::Interwave>;
// clang-format on // clang-format on
#include "run_splitK_gemm_example.inc" #include "run_splitK_gemm_example.inc"
......
...@@ -401,7 +401,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 ...@@ -401,7 +401,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
GetCBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock().GetElementSpaceSize(); GetCBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock().GetElementSpaceSize();
return math::max(a_block_space_size * sizeof(FloatA) + b_block_space_size * sizeof(FloatB), return math::max(a_block_space_size * sizeof(FloatA) + b_block_space_size * sizeof(FloatB),
c_block_size * sizeof(FloatC)); c_block_size * sizeof(FloatAcc));
} }
__host__ __device__ static constexpr bool CheckValidity(const Argument& karg) __host__ __device__ static constexpr bool CheckValidity(const Argument& karg)
...@@ -834,8 +834,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 ...@@ -834,8 +834,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
constexpr auto a_block_space_size = constexpr auto a_block_space_size =
math::integer_least_multiple(a_k0_m_k1_block_desc.GetElementSpaceSize(), max_lds_align); math::integer_least_multiple(a_k0_m_k1_block_desc.GetElementSpaceSize(), max_lds_align);
auto* p_a_block = static_cast<FloatA*>(p_shared_block); FloatA* p_a_block = static_cast<FloatA*>(p_shared_block);
auto* p_b_block = static_cast<FloatB*>(p_shared_block) + a_block_space_size; FloatB* p_b_block = static_cast<FloatB*>(p_shared_block) + a_block_space_size;
constexpr auto a_block_slice_copy_step = make_multi_index(0, K0PerBlock, 0, 0); constexpr auto a_block_slice_copy_step = make_multi_index(0, K0PerBlock, 0, 0);
constexpr auto b_block_slice_copy_step = make_multi_index(0, K0PerBlock, 0, 0); constexpr auto b_block_slice_copy_step = make_multi_index(0, K0PerBlock, 0, 0);
...@@ -892,7 +892,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 ...@@ -892,7 +892,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
GetCBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); GetCBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
auto c_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>( auto c_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<FloatC*>(p_shared_block), static_cast<FloatAcc*>(p_shared_block),
c_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); c_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
...@@ -943,7 +943,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 ...@@ -943,7 +943,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
// VGPR to LDS // VGPR to LDS
auto c_thread_copy_vgpr_to_lds = auto c_thread_copy_vgpr_to_lds =
ThreadwiseTensorSliceTransfer_v1r3<FloatAcc, ThreadwiseTensorSliceTransfer_v1r3<FloatAcc,
FloatC, FloatAcc,
decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc), decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc),
decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2), decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
ck::tensor_operation::element_wise::PassThrough, ck::tensor_operation::element_wise::PassThrough,
...@@ -983,7 +983,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 ...@@ -983,7 +983,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
CShuffleNRepeatPerShuffle * NWave * NPerXDL>, // BlockSliceLengths, CShuffleNRepeatPerShuffle * NWave * NPerXDL>, // BlockSliceLengths,
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
FloatC, // typename SrcData, FloatAcc, // typename SrcData,
FloatC, // typename DstData, FloatC, // typename DstData,
decltype(c_block_desc_mblock_mperblock_nblock_nperblock), decltype(c_block_desc_mblock_mperblock_nblock_nperblock),
decltype(c_grid_desc_mblock_mperblock_nblock_nperblock), decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
......
...@@ -1160,8 +1160,6 @@ struct ThreadwiseTensorSliceTransfer_v4 ...@@ -1160,8 +1160,6 @@ struct ThreadwiseTensorSliceTransfer_v4
src_tmp_vector.template AsType<SrcData>()(i) = src_buf[Number<src_offset>{}]; src_tmp_vector.template AsType<SrcData>()(i) = src_buf[Number<src_offset>{}];
}); });
} }
#if 0
// copy data from src_tmp_vector to dst_tmp_vector (data cast data from SrcData to // copy data from src_tmp_vector to dst_tmp_vector (data cast data from SrcData to
// DstData) // DstData)
vector_type_maker_t<DstData, SrcScalarPerVector> dst_tmp_vector; vector_type_maker_t<DstData, SrcScalarPerVector> dst_tmp_vector;
...@@ -1171,15 +1169,13 @@ struct ThreadwiseTensorSliceTransfer_v4 ...@@ -1171,15 +1169,13 @@ struct ThreadwiseTensorSliceTransfer_v4
dst_tmp_vector.template AsType<DstData>()(i) = dst_tmp_vector.template AsType<DstData>()(i) =
type_convert<DstData>(src_tmp_vector.template AsType<SrcData>()[i]); type_convert<DstData>(src_tmp_vector.template AsType<SrcData>()[i]);
}); });
#endif
// copy data from dst_tmp_vector into dst_buf // copy data from dst_tmp_vector into dst_buf
static_for<0, SrcScalarPerVector, 1>{}([&](auto i) { static_for<0, SrcScalarPerVector, 1>{}([&](auto i) {
constexpr index_t dst_offset = dst_desc.CalculateOffset( constexpr index_t dst_offset = dst_desc.CalculateOffset(
dst_origin_idx + data_to_origin_disp_idx + i * src_scalar_step_in_vector); dst_origin_idx + data_to_origin_disp_idx + i * src_scalar_step_in_vector);
dst_buf(Number<dst_offset>{}) = dst_buf(Number<dst_offset>{}) = dst_tmp_vector.template AsType<DstData>()[i];
type_convert<DstData>(src_tmp_vector.template AsType<SrcData>()[i]);
}); });
}); });
} }
......
...@@ -278,7 +278,7 @@ __device__ float llvm_amdgcn_raw_buffer_atomic_add_fp32( ...@@ -278,7 +278,7 @@ __device__ float llvm_amdgcn_raw_buffer_atomic_add_fp32(
index_t soffset, index_t soffset,
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.fadd.f32"); index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.fadd.f32");
// buffer atomic-add fp32 // buffer atomic-max fp64
__device__ double __device__ double
llvm_amdgcn_raw_buffer_atomic_max_fp64(double vdata, llvm_amdgcn_raw_buffer_atomic_max_fp64(double vdata,
int32x4_t rsrc, // dst_wave_buffer_resource int32x4_t rsrc, // dst_wave_buffer_resource
...@@ -420,10 +420,124 @@ __device__ typename vector_type<T, N>::type amd_buffer_load_impl(int32x4_t src_w ...@@ -420,10 +420,124 @@ __device__ typename vector_type<T, N>::type amd_buffer_load_impl(int32x4_t src_w
(is_same<T, int8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)), (is_same<T, int8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)),
"wrong! not implemented"); "wrong! not implemented");
using r_t = typename vector_type<T, N>::type; if constexpr(is_same<T, float>::value) // fp32
auto raw_data = amd_buffer_load_impl_raw<sizeof(T) * N, coherence>( {
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset); if constexpr(N == 1)
return bit_cast<r_t>(raw_data); {
return llvm_amdgcn_raw_buffer_load_fp32(src_wave_buffer_resource,
src_thread_addr_offset,
src_wave_addr_offset,
static_cast<index_t>(coherence));
}
else if constexpr(N == 2)
{
return llvm_amdgcn_raw_buffer_load_fp32x2(src_wave_buffer_resource,
src_thread_addr_offset,
src_wave_addr_offset,
static_cast<index_t>(coherence));
}
else if constexpr(N == 4)
{
return llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource,
src_thread_addr_offset,
src_wave_addr_offset,
static_cast<index_t>(coherence));
}
else if constexpr(N == 8)
{
vector_type<float, 8> tmp;
tmp.AsType<float4_t>()(Number<0>{}) =
llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource,
src_thread_addr_offset,
src_wave_addr_offset,
static_cast<index_t>(coherence));
tmp.AsType<float4_t>()(Number<1>{}) =
llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource,
src_thread_addr_offset,
src_wave_addr_offset + 4 * sizeof(float),
static_cast<index_t>(coherence));
return tmp.AsType<float8_t>()(Number<0>{});
}
}
else if constexpr(is_same<T, half_t>::value) // fp16
{
if constexpr(N == 1)
{
return llvm_amdgcn_raw_buffer_load_fp16(src_wave_buffer_resource,
src_thread_addr_offset,
src_wave_addr_offset,
static_cast<index_t>(coherence));
}
else if constexpr(N == 2)
{
return llvm_amdgcn_raw_buffer_load_fp16x2(src_wave_buffer_resource,
src_thread_addr_offset,
src_wave_addr_offset,
static_cast<index_t>(coherence));
}
else if constexpr(N == 4)
{
return llvm_amdgcn_raw_buffer_load_fp16x4(src_wave_buffer_resource,
src_thread_addr_offset,
src_wave_addr_offset,
static_cast<index_t>(coherence));
}
else if constexpr(N == 8)
{
// use fp32 load to mimic fp16 load
float4_t tmp = llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource,
src_thread_addr_offset,
src_wave_addr_offset,
static_cast<index_t>(coherence));
return bit_cast<half8_t>(tmp);
}
}
else if constexpr(is_same<T, bhalf_t>::value) // bf16
{
if constexpr(N == 1)
{
return llvm_amdgcn_raw_buffer_load_i16(src_wave_buffer_resource,
src_thread_addr_offset,
src_wave_addr_offset,
static_cast<index_t>(coherence));
}
else if constexpr(N == 2)
{
return llvm_amdgcn_raw_buffer_load_i16x2(src_wave_buffer_resource,
src_thread_addr_offset,
src_wave_addr_offset,
static_cast<index_t>(coherence));
}
else if constexpr(N == 4)
{
return llvm_amdgcn_raw_buffer_load_i16x4(src_wave_buffer_resource,
src_thread_addr_offset,
src_wave_addr_offset,
static_cast<index_t>(coherence));
}
else if constexpr(N == 8)
{
int32x4_t tmp = llvm_amdgcn_raw_buffer_load_i32x4(src_wave_buffer_resource,
src_thread_addr_offset,
src_wave_addr_offset,
static_cast<index_t>(coherence));
return bit_cast<bhalf8_t>(tmp);
}
}
else // other datatype
{
using r_t = typename vector_type<T, N>::type;
auto raw_data = amd_buffer_load_impl_raw<sizeof(T) * N, coherence>(
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset);
return bit_cast<r_t>(raw_data);
}
} }
template <index_t N, AmdBufferCoherenceEnum coherence = AmdBufferCoherenceEnum::DefaultCoherence> template <index_t N, AmdBufferCoherenceEnum coherence = AmdBufferCoherenceEnum::DefaultCoherence>
...@@ -542,12 +656,150 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src ...@@ -542,12 +656,150 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src
(is_same<T, int8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)), (is_same<T, int8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)),
"wrong! not implemented"); "wrong! not implemented");
using r_t = typename vector_type<int8_t, sizeof(T) * N>::type; if constexpr(is_same<T, float>::value) // fp32
{
if constexpr(N == 1)
{
llvm_amdgcn_raw_buffer_store_fp32(src_thread_data,
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
static_cast<index_t>(coherence));
}
else if constexpr(N == 2)
{
llvm_amdgcn_raw_buffer_store_fp32x2(src_thread_data,
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
static_cast<index_t>(coherence));
}
else if constexpr(N == 4)
{
llvm_amdgcn_raw_buffer_store_fp32x4(src_thread_data,
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
static_cast<index_t>(coherence));
}
else if constexpr(N == 8)
{
vector_type<float, 8> tmp{src_thread_data};
llvm_amdgcn_raw_buffer_store_fp32x4(tmp.AsType<float4_t>()[Number<0>{}],
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
static_cast<index_t>(coherence));
llvm_amdgcn_raw_buffer_store_fp32x4(tmp.AsType<float4_t>()[Number<1>{}],
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset + 4 * sizeof(float),
static_cast<index_t>(coherence));
}
}
else if constexpr(is_same<T, half_t>::value) // fp16
{
if constexpr(N == 1)
{
llvm_amdgcn_raw_buffer_store_fp16(src_thread_data,
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
static_cast<index_t>(coherence));
}
else if constexpr(N == 2)
{
llvm_amdgcn_raw_buffer_store_fp16x2(src_thread_data,
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
static_cast<index_t>(coherence));
}
else if constexpr(N == 4)
{
llvm_amdgcn_raw_buffer_store_fp16x4(src_thread_data,
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
static_cast<index_t>(coherence));
}
else if constexpr(N == 8)
{
#if 0
vector_type<half_t, 8> tmp{src_thread_data};
llvm_amdgcn_raw_buffer_store_fp16x4(tmp.AsType<half4_t>()[Number<0>{}],
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
static_cast<index_t>(coherence));
llvm_amdgcn_raw_buffer_store_fp16x4(tmp.AsType<half4_t>()[Number<1>{}],
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset + 4 * sizeof(half_t),
static_cast<index_t>(coherence));
#else
llvm_amdgcn_raw_buffer_store_fp32x4(bit_cast<float4_t>(src_thread_data),
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
static_cast<index_t>(coherence));
#endif
}
}
else if constexpr(is_same<T, bhalf_t>::value) // bf16
{
if constexpr(N == 1)
{
llvm_amdgcn_raw_buffer_store_i16(src_thread_data,
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
static_cast<index_t>(coherence));
}
else if constexpr(N == 2)
{
llvm_amdgcn_raw_buffer_store_i16x2(src_thread_data,
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
static_cast<index_t>(coherence));
}
else if constexpr(N == 4)
{
llvm_amdgcn_raw_buffer_store_i16x4(src_thread_data,
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
static_cast<index_t>(coherence));
}
else if constexpr(N == 8)
{
vector_type<bhalf_t, 8> tmp{src_thread_data};
llvm_amdgcn_raw_buffer_store_i16x4(tmp.AsType<bhalf4_t>()[Number<0>{}],
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
static_cast<index_t>(coherence));
llvm_amdgcn_raw_buffer_store_i16x4(tmp.AsType<bhalf4_t>()[Number<1>{}],
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset + 4 * sizeof(bhalf_t),
static_cast<index_t>(coherence));
}
}
else
{
using r_t = typename vector_type<int8_t, sizeof(T) * N>::type;
amd_buffer_store_impl_raw<sizeof(T) * N, coherence>(bit_cast<r_t>(src_thread_data), amd_buffer_store_impl_raw<sizeof(T) * N, coherence>(bit_cast<r_t>(src_thread_data),
dst_wave_buffer_resource, dst_wave_buffer_resource,
dst_thread_addr_offset, dst_thread_addr_offset,
dst_wave_addr_offset); dst_wave_addr_offset);
}
} }
template <typename T, index_t N> template <typename T, index_t N>
......
...@@ -366,10 +366,12 @@ inline __host__ __device__ half_t type_convert<half_t, f8_t>(f8_t x) ...@@ -366,10 +366,12 @@ inline __host__ __device__ half_t type_convert<half_t, f8_t>(f8_t x)
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) #if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
// use native conversion to float and convert to fp16 // use native conversion to float and convert to fp16
// return type_convert<half_t>(type_convert<float>(x)); // return type_convert<half_t>(type_convert<float>(x));
return static_cast<half_t>(x); // return static_cast<half_t>(x);
return static_cast<half_t>(bit_cast<int8_t>(x));
#else #else
constexpr bool negative_zero_nan = true; // constexpr bool negative_zero_nan = true;
return utils::cast_from_f8<f8_t, half_t, negative_zero_nan>(x); // return utils::cast_from_f8<f8_t, half_t, negative_zero_nan>(x);
return static_cast<half_t>(bit_cast<int8_t>(x));
#endif #endif
} }
......
set(GEMM_SPLITK_INSTANCES) set(GEMM_SPLITK_INSTANCES)
list(APPEND GEMM_SPLITK_INSTANCES device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instance.cpp list(APPEND GEMM_SPLITK_INSTANCES
device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instance.cpp device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instance.cpp
device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instance.cpp #device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instance.cpp
device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instance.cpp #device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instance.cpp
device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instance.cpp #device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instance.cpp
device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instance.cpp device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instance.cpp
device_gemm_xdl_splitk_f16_f16_f16_km_kn_mn_instance.cpp device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instance.cpp
device_gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instance.cpp device_gemm_xdl_splitk_f16_f16_f16_km_kn_mn_instance.cpp
device_gemm_xdl_splitk_fp8_f16_f16_mk_kn_mn_instance.cpp device_gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instance.cpp
device_gemm_xdl_splitk_fp8_f16_f16_mk_nk_mn_instance.cpp #device_gemm_xdl_splitk_fp8_f16_f16_mk_kn_mn_instance.cpp
device_gemm_xdl_splitk_fp8_f16_f16_km_kn_mn_instance.cpp #device_gemm_xdl_splitk_fp8_f16_f16_mk_nk_mn_instance.cpp
device_gemm_xdl_splitk_fp8_f16_f16_km_nk_mn_instance.cpp #device_gemm_xdl_splitk_fp8_f16_f16_km_kn_mn_instance.cpp
device_gemm_xdl_splitk_f16_fp8_f16_mk_kn_mn_instance.cpp #device_gemm_xdl_splitk_fp8_f16_f16_km_nk_mn_instance.cpp
device_gemm_xdl_splitk_f16_fp8_f16_mk_nk_mn_instance.cpp device_gemm_xdl_splitk_f16_fp8_f16_mk_kn_mn_instance.cpp
device_gemm_xdl_splitk_f16_fp8_f16_km_kn_mn_instance.cpp device_gemm_xdl_splitk_f16_fp8_f16_mk_nk_mn_instance.cpp
device_gemm_xdl_splitk_f16_fp8_f16_km_nk_mn_instance.cpp device_gemm_xdl_splitk_f16_fp8_f16_km_kn_mn_instance.cpp
#device_gemm_xdl_splitk_f16_f16_f16_comp_fp8_mk_kn_mn_instance.cpp device_gemm_xdl_splitk_f16_fp8_f16_km_nk_mn_instance.cpp
#device_gemm_xdl_splitk_f16_f16_f16_comp_fp8_mk_nk_mn_instance.cpp #device_gemm_xdl_splitk_f16_f16_f16_comp_fp8_mk_kn_mn_instance.cpp
#device_gemm_xdl_splitk_f16_f16_f16_comp_fp8_km_kn_mn_instance.cpp #device_gemm_xdl_splitk_f16_f16_f16_comp_fp8_mk_nk_mn_instance.cpp
#device_gemm_xdl_splitk_f16_f16_f16_comp_fp8_km_nk_mn_instance.cpp #device_gemm_xdl_splitk_f16_f16_f16_comp_fp8_km_kn_mn_instance.cpp
) #device_gemm_xdl_splitk_f16_f16_f16_comp_fp8_km_nk_mn_instance.cpp
)
add_instance_library(device_gemm_splitk_instance ${GEMM_SPLITK_INSTANCES}) add_instance_library(device_gemm_splitk_instance ${GEMM_SPLITK_INSTANCES})
...@@ -236,8 +236,8 @@ int profile_gemm_impl(int do_verification, ...@@ -236,8 +236,8 @@ int profile_gemm_impl(int do_verification,
{ {
std::string op_name = op_ptr->GetTypeString(); std::string op_name = op_ptr->GetTypeString();
float avg_time = invoker_ptr->Run(argument_ptr.get(), float avg_time =
StreamConfig{nullptr, time_kernel, 0, 50, 200}); invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel});
std::size_t flop = std::size_t(2) * M * N * K; std::size_t flop = std::size_t(2) * M * N * K;
......
...@@ -200,8 +200,8 @@ bool profile_gemm_splitk_impl(int do_verification, ...@@ -200,8 +200,8 @@ bool profile_gemm_splitk_impl(int do_verification,
std::string op_name = op_ptr->GetTypeString(); std::string op_name = op_ptr->GetTypeString();
float ave_time = invoker_ptr->Run(argument_ptr.get(), float ave_time =
StreamConfig{nullptr, time_kernel, 0, 50, 200}); invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel});
std::size_t flop = std::size_t(2) * M * N * K; std::size_t flop = std::size_t(2) * M * N * K;
......
...@@ -122,96 +122,91 @@ int profile_gemm_splitk(int argc, char* argv[]) ...@@ -122,96 +122,91 @@ int profile_gemm_splitk(int argc, char* argv[])
return pass ? 0 : 1; return pass ? 0 : 1;
}; };
#if 0
if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::MK_KN_MN) if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::MK_KN_MN)
{ {
return profile(F32{}, F32{}, F32{}, F32{}, Row{}, Row{}, Row{}, F32{}); return profile(F32{}, F32{}, F32{}, F32{}, Row{}, Row{}, Row{}, F32{});
} }
else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::MK_NK_MN) if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::MK_NK_MN)
{ {
return profile(F32{}, F32{}, F32{}, F32{}, Row{}, Col{}, Row{}, F32{}); return profile(F32{}, F32{}, F32{}, F32{}, Row{}, Col{}, Row{}, F32{});
} }
else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::KM_KN_MN) if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::KM_KN_MN)
{ {
return profile(F32{}, F32{}, F32{}, F32{}, Col{}, Row{}, Row{}, F32{}); return profile(F32{}, F32{}, F32{}, F32{}, Col{}, Row{}, Row{}, F32{});
} }
else if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::KM_NK_MN) if(data_type == GemmDataType::F32_F32_F32 && layout == GemmMatrixLayout::KM_NK_MN)
{ {
return profile(F32{}, F32{}, F32{}, F32{}, Col{}, Col{}, Row{}, F32{}); return profile(F32{}, F32{}, F32{}, F32{}, Col{}, Col{}, Row{}, F32{});
} }
else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_KN_MN) #endif
if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_KN_MN)
{ {
return profile(F16{}, F16{}, F32{}, F16{}, Row{}, Row{}, Row{}, F16{}); return profile(F16{}, F16{}, F32{}, F16{}, Row{}, Row{}, Row{}, F16{});
} }
else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_NK_MN) if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_NK_MN)
{ {
return profile(F16{}, F16{}, F32{}, F16{}, Row{}, Col{}, Row{}, F16{}); return profile(F16{}, F16{}, F32{}, F16{}, Row{}, Col{}, Row{}, F16{});
} }
else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_KN_MN) if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_KN_MN)
{ {
return profile(F16{}, F16{}, F32{}, F16{}, Col{}, Row{}, Row{}, F16{}); return profile(F16{}, F16{}, F32{}, F16{}, Col{}, Row{}, Row{}, F16{});
} }
else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_NK_MN) if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_NK_MN)
{ {
return profile(F16{}, F16{}, F32{}, F16{}, Col{}, Col{}, Row{}, F16{}); return profile(F16{}, F16{}, F32{}, F16{}, Col{}, Col{}, Row{}, F16{});
} }
#if defined CK_ENABLE_FP8 #if defined CK_ENABLE_FP8
else if(data_type == GemmDataType::F8_F16_F16 && layout == GemmMatrixLayout::MK_KN_MN) // if(data_type == GemmDataType::F8_F16_F16 && layout == GemmMatrixLayout::MK_KN_MN)
{ //{
return profile(F8{}, F16{}, F32{}, F16{}, Row{}, Row{}, Row{}, F16{}); // return profile(F8{}, F16{}, F32{}, F16{}, Row{}, Row{}, Row{}, F16{});
} //}
else if(data_type == GemmDataType::F8_F16_F16 && layout == GemmMatrixLayout::MK_NK_MN) // if(data_type == GemmDataType::F8_F16_F16 && layout == GemmMatrixLayout::MK_NK_MN)
{ //{
return profile(F8{}, F16{}, F32{}, F16{}, Row{}, Col{}, Row{}, F16{}); // return profile(F8{}, F16{}, F32{}, F16{}, Row{}, Col{}, Row{}, F16{});
} //}
else if(data_type == GemmDataType::F8_F16_F16 && layout == GemmMatrixLayout::KM_KN_MN) // if(data_type == GemmDataType::F8_F16_F16 && layout == GemmMatrixLayout::KM_KN_MN)
{ //{
return profile(F8{}, F16{}, F32{}, F16{}, Col{}, Row{}, Row{}, F16{}); // return profile(F8{}, F16{}, F32{}, F16{}, Col{}, Row{}, Row{}, F16{});
} //}
else if(data_type == GemmDataType::F8_F16_F16 && layout == GemmMatrixLayout::KM_NK_MN) // if(data_type == GemmDataType::F8_F16_F16 && layout == GemmMatrixLayout::KM_NK_MN)
{ //{
return profile(F8{}, F16{}, F32{}, F16{}, Col{}, Col{}, Row{}, F16{}); // return profile(F8{}, F16{}, F32{}, F16{}, Col{}, Col{}, Row{}, F16{});
} //}
else if(data_type == GemmDataType::F16_F8_F16 && layout == GemmMatrixLayout::MK_KN_MN) if(data_type == GemmDataType::F16_F8_F16 && layout == GemmMatrixLayout::MK_KN_MN)
{ {
return profile(F16{}, F8{}, F32{}, F16{}, Row{}, Row{}, Row{}, F16{}); return profile(F16{}, F8{}, F32{}, F16{}, Row{}, Row{}, Row{}, F16{});
} }
else if(data_type == GemmDataType::F16_F8_F16 && layout == GemmMatrixLayout::MK_NK_MN) if(data_type == GemmDataType::F16_F8_F16 && layout == GemmMatrixLayout::MK_NK_MN)
{ {
return profile(F16{}, F8{}, F32{}, F16{}, Row{}, Col{}, Row{}, F16{}); return profile(F16{}, F8{}, F32{}, F16{}, Row{}, Col{}, Row{}, F16{});
} }
else if(data_type == GemmDataType::F16_F8_F16 && layout == GemmMatrixLayout::KM_KN_MN) if(data_type == GemmDataType::F16_F8_F16 && layout == GemmMatrixLayout::KM_KN_MN)
{ {
return profile(F16{}, F8{}, F32{}, F16{}, Col{}, Row{}, Row{}, F16{}); return profile(F16{}, F8{}, F32{}, F16{}, Col{}, Row{}, Row{}, F16{});
} }
else if(data_type == GemmDataType::F16_F8_F16 && layout == GemmMatrixLayout::KM_NK_MN) if(data_type == GemmDataType::F16_F8_F16 && layout == GemmMatrixLayout::KM_NK_MN)
{ {
return profile(F16{}, F8{}, F32{}, F16{}, Col{}, Col{}, Row{}, F16{}); return profile(F16{}, F8{}, F32{}, F16{}, Col{}, Col{}, Row{}, F16{});
} }
#if 0 // if(data_type == GemmDataType::F16_F16_F16_F8 && layout == GemmMatrixLayout::MK_KN_MN)
else if(data_type == GemmDataType::F16_F16_F16_F8 && layout == GemmMatrixLayout::MK_KN_MN) //{
{ // return profile(F16{}, F16{}, F32{}, F16{}, Row{}, Row{}, Row{}, F8{});
return profile(F16{}, F16{}, F32{}, F16{}, Row{}, Row{}, Row{}, F8{}); //}
} // if(data_type == GemmDataType::F16_F16_F16_F8 && layout == GemmMatrixLayout::MK_NK_MN)
else if(data_type == GemmDataType::F16_F16_F16_F8 && layout == GemmMatrixLayout::MK_NK_MN) //{
{ // return profile(F16{}, F16{}, F32{}, F16{}, Row{}, Col{}, Row{}, F8{});
return profile(F16{}, F16{}, F32{}, F16{}, Row{}, Col{}, Row{}, F8{}); //}
} // if(data_type == GemmDataType::F16_F16_F16_F8 && layout == GemmMatrixLayout::KM_KN_MN)
else if(data_type == GemmDataType::F16_F16_F16_F8 && layout == GemmMatrixLayout::KM_KN_MN) //{
{ // return profile(F16{}, F16{}, F32{}, F16{}, Col{}, Row{}, Row{}, F8{});
return profile(F16{}, F16{}, F32{}, F16{}, Col{}, Row{}, Row{}, F8{}); //}
} // if(data_type == GemmDataType::F16_F16_F16_F8 && layout == GemmMatrixLayout::KM_NK_MN)
else if(data_type == GemmDataType::F16_F16_F16_F8 && layout == GemmMatrixLayout::KM_NK_MN) //{
{ // return profile(F16{}, F16{}, F32{}, F16{}, Col{}, Col{}, Row{}, F8{});
return profile(F16{}, F16{}, F32{}, F16{}, Col{}, Col{}, Row{}, F8{}); //}
}
#endif
#endif #endif
else return 0;
{
std::cout << "this data_type & layout is not implemented" << std::endl;
return 1;
}
} }
REGISTER_PROFILER_OPERATION(OP_NAME, OP_DESC, profile_gemm_splitk); REGISTER_PROFILER_OPERATION(OP_NAME, OP_DESC, profile_gemm_splitk);
...@@ -11,7 +11,7 @@ cmake ...@@ -11,7 +11,7 @@ cmake
-D CMAKE_CXX_FLAGS="--save-temps -std=c++17 -O3 -ftemplate-backtrace-limit=0 -fPIE -Wno-gnu-line-marker" \ -D CMAKE_CXX_FLAGS="--save-temps -std=c++17 -O3 -ftemplate-backtrace-limit=0 -fPIE -Wno-gnu-line-marker" \
-D CMAKE_BUILD_TYPE=Release \ -D CMAKE_BUILD_TYPE=Release \
-D BUILD_DEV=ON \ -D BUILD_DEV=ON \
-D GPU_TARGETS="gfx942" \ -D GPU_TARGETS="gfx90a" \
-D CMAKE_VERBOSE_MAKEFILE:BOOL=ON \ -D CMAKE_VERBOSE_MAKEFILE:BOOL=ON \
-D USE_BITINT_EXTENSION_INT4=OFF \ -D USE_BITINT_EXTENSION_INT4=OFF \
${MY_PROJECT_SOURCE} ${MY_PROJECT_SOURCE}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment