Commit 0f3b88bf authored by Jing Zhang

add a prototype of int4

parent cfac9497
@@ -543,7 +543,7 @@ ENDIF()
 ENDFOREACH()
 add_custom_target(instances DEPENDS utility;${CK_DEVICE_INSTANCES} SOURCES ${INSTANCE_FILES})
-add_subdirectory(library)
+#add_subdirectory(library)
 if(NOT GPU_ARCHS)
 rocm_package_setup_component(tests
@@ -556,20 +556,20 @@ if(NOT GPU_ARCHS)
 PACKAGE_NAME examples
 )
 add_subdirectory(example)
-if(BUILD_TESTING)
-add_subdirectory(test)
-endif()
+#if(BUILD_TESTING)
+#add_subdirectory(test)
+#endif()
 endif()
 rocm_package_setup_component(profiler
 LIBRARY_NAME composablekernel
 PACKAGE_NAME ckprofiler
 )
-add_subdirectory(profiler)
-if(CK_USE_CODEGEN AND (GPU_TARGETS MATCHES "gfx9" OR GPU_ARCHS))
-add_subdirectory(codegen)
-endif()
+#add_subdirectory(profiler)
+#if(CK_USE_CODEGEN AND (GPU_TARGETS MATCHES "gfx9" OR GPU_ARCHS))
+#add_subdirectory(codegen)
+#endif()
 #Create an interface target for the include only files and call it "composablekernels"
 include(CMakePackageConfigHelpers)
......
@@ -29,6 +29,7 @@ add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp16_v3)
 add_example_executable(example_gemm_xdl_fp8_v3 gemm_xdl_fp8_v3.cpp)
 add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp8_v3)
 add_example_executable(example_gemm_xdl_fp16_fp8_v3 gemm_xdl_fp16_fp8_v3.cpp)
+add_example_executable(example_gemm_xdl_fp16_pk_i4_v3 gemm_xdl_fp16_pk_i4_v3.cpp)
 add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp16_fp8_v3)
 add_example_executable(example_gemm_xdl_bf16_v3 gemm_xdl_bf16_v3.cpp)
 add_example_dependencies(example_gemm_xdl example_gemm_xdl_bf16_v3)
......
@@ -5,8 +5,8 @@
 #include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp"
-using ADataType = ck::f8_t;
-using BDataType = ck::half_t;
+using ADataType = ck::half_t;
+using BDataType = ck::f8_t;
 using AccDataType = float;
 using CShuffleDataType = ck::half_t;
 using CDataType = ck::half_t;
@@ -27,17 +27,31 @@ using DeviceGemmV2Instance =
 ALayout, BLayout, CLayout,
 ADataType, BDataType, CDataType, AccDataType, CShuffleDataType,
 AElementOp, BElementOp, CElementOp, GemmDefault,
+#if 0
 64,
 16, 16,
-64, 16, 8,
+256, 8, 16,
 16, 16,
 1, 1,
-S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>,
-2, 16, 16, 0,
-S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>,
+S<32, 2, 1>, S<1, 0, 2>, S<1, 0, 2>,
 2, 8, 8, 0,
+S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>,
+2, 16, 16, 0,
 1, 1, S<1, 16, 1, 4>, 4,
-ck::BlockGemmPipelineScheduler::Intrawave,ck::BlockGemmPipelineVersion::v1>;
+ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1>;
+#else
+128,
+16, 32,
+128, 8, 16,
+16, 16,
+1, 1,
+S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>,
+2, 8, 8, 0,
+S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>,
+2, 16, 16, 0,
+1, 1, S<1, 16, 1, 8>, 4,
+ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1>;
+#endif
 // clang-format on
 using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include "common.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3.hpp"
using ADataType = ck::half_t;
using BDataType = ck::pk_i4_t;
using AccDataType = float;
using CShuffleDataType = ck::half_t;
using CDataType = ck::half_t;
using ALayout = Row;
using BLayout = Col;
using CLayout = Row;
inline __host__ __device__ ck::half2_t type_convert_packed_i4_to_half2(ck::pk_i4_t x)
{
    // One pk_i4_t byte carries two 4-bit elements: the low nibble is element 0,
    // the high nibble is element 1. Each nibble is widened to half precision.
    uint8_t x_u8 = ck::bit_cast<uint8_t>(x);
    uint8_t x_l  = (x_u8 & 0x0f);
    uint8_t x_h  = (x_u8 & 0xf0) >> 4;

    auto l_f16 = ck::type_convert<ck::half_t>(x_l);
    auto h_f16 = ck::type_convert<ck::half_t>(x_h);

    return {l_f16, h_f16};
}

// Element-wise functor form of the conversion above; is_pack2_invocable marks it
// as operating on packed pairs.
struct ElementwisePackedI4ToHalf2
{
    __host__ __device__ void operator()(ck::half2_t& y, const ck::pk_i4_t& x) const
    {
        y = type_convert_packed_i4_to_half2(x);
    }

    constexpr const static bool is_pack2_invocable = true;
};
using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CElementOp = PassThrough;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
// clang-format off
using DeviceGemmV2Instance =
ck::tensor_operation::device::DeviceGemm_Xdl_CShuffleV3<
ALayout, BLayout, CLayout,
ADataType, BDataType, CDataType, AccDataType, CShuffleDataType,
AElementOp, BElementOp, CElementOp, GemmDefault,
#if 0
64,
16, 16,
256, 8, 32,
16, 16,
1, 1,
S<32, 2, 1>, S<1, 0, 2>, S<1, 0, 2>,
2, 8, 8, 0,
S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>,
2, 32, 32, 0,
1, 1, S<1, 16, 1, 4>, 4,
ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v2>;
#else
128,
16, 32,
128, 8, 32,
16, 16,
1, 1,
S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>,
2, 8, 8, 0,
S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>,
2, 32, 32, 0,
1, 1, S<1, 16, 1, 8>, 4,
ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1>;
#endif
// clang-format on
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
BDataType,
CDataType,
AccDataType,
PassThrough,
PassThrough,
PassThrough>;
#include "run_gemm_example_v2.inc"
int main(int argc, char* argv[]) { return !run_gemm_splitk_example(argc, argv); }
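The example above reads each `pk_i4_t` byte as two unsigned 4-bit elements, low nibble first. The layout can be checked with a minimal host-only sketch, assuming plain C++ with `float` standing in for `ck::half_t` (the helper name `unpack_pk_i4` is illustrative, not a CK API):

```cpp
#include <cstdint>
#include <cstdio>
#include <utility>

// Unpack one packed byte into (element 0, element 1): low nibble first,
// both widened as unsigned values, mirroring type_convert_packed_i4_to_half2.
std::pair<float, float> unpack_pk_i4(std::uint8_t x_u8)
{
    const std::uint8_t x_l = (x_u8 & 0x0f);      // element 0: low nibble
    const std::uint8_t x_h = (x_u8 & 0xf0) >> 4; // element 1: high nibble
    return {static_cast<float>(x_l), static_cast<float>(x_h)};
}

int main()
{
    // 0xB4 packs the unsigned 4-bit values 4 (low nibble) and 11 (high nibble).
    const auto [lo, hi] = unpack_pk_i4(0xB4);
    std::printf("low = %.0f, high = %.0f\n", lo, hi); // prints: low = 4, high = 11
    return 0;
}
```

As in the prototype's conversion, the nibbles are read as unsigned values 0..15; a signed int4 interpretation (-8..7) would require an additional sign-extension step.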
@@ -12,7 +12,7 @@ using CShuffleDataType = ck::half_t;
 using CDataType = ck::half_t;
 using ALayout = Row;
-using BLayout = Row;
+using BLayout = Col;
 using CLayout = Row;
 using AElementOp = PassThrough;
@@ -27,17 +27,17 @@ using DeviceGemmV2Instance =
 ALayout, BLayout, CLayout,
 ADataType, BDataType, CDataType, AccDataType, CShuffleDataType,
 PassThrough, PassThrough, PassThrough, GemmDefault,
-256,
-224, 256,
-64, 8, 2,
+64,
+16, 16,
+256, 8, 8,
 16, 16,
-7, 8,
-S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>,
+1, 1,
+S<32, 2, 1>, S<1, 0, 2>, S<1, 0, 2>,
 2, 8, 8, 0,
-S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>,
-1, 8, 2, 0,
-1, 2, S<1, 32, 1, 8>, 8,
-ck::BlockGemmPipelineScheduler::Intrawave,ck::BlockGemmPipelineVersion::v3>;
+S<32, 2, 1>, S<1, 0, 2>, S<1, 0, 2>,
+2, 8, 8, 0,
+1, 1, S<1, 16, 1, 4>, 4,
+ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v2>;
 // clang-format on
 using ReferenceGemmInstance = ck::tensor_operation::host::
......
@@ -228,6 +228,7 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
 }
 bool pass = true;
+#if 0
 if(config.do_verification)
 {
 auto ref_gemm = ReferenceGemmInstance{};
@@ -257,11 +258,12 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
 get_atol<CDataType>());
 #endif
 }
+#endif
 if(config.time_kernel)
 {
 ave_time =
-invoker.Run(argument, StreamConfig{nullptr, config.time_kernel, 0, 5, 10, true, 4});
+invoker.Run(argument, StreamConfig{nullptr, config.time_kernel, 0, 20, 50, true, 50});
 std::size_t flop = 2_uz * M * N * K;
 std::size_t num_btype =
......
@@ -22,6 +22,19 @@ struct PassThroughPack2
         auto t = type_convert<float2_t>(x);
         y = type_convert<half2_t>(t);
     }
+
+    __host__ __device__ constexpr void operator()(ck::half2_t& y, const ck::pk_i4_t& x) const
+    {
+        uint8_t x_u8 = ck::bit_cast<uint8_t>(x);
+        uint8_t x_l  = (x_u8 & 0x0f) >> 0;
+        uint8_t x_h  = (x_u8 & 0xf0) >> 4;
+
+        auto l_f16 = ck::type_convert<ck::half_t>(x_l);
+        auto h_f16 = ck::type_convert<ck::half_t>(x_h);
+
+        y = {l_f16, h_f16};
+    }
+
     constexpr const static bool is_pack2_invocable = true;
 };
......
@@ -1007,6 +1007,13 @@ struct ThreadwiseTensorSliceTransfer_v4
     using SrcCoordStep = decltype(make_tensor_coordinate_step(SrcDesc{}, Index{}));
 
+    static constexpr index_t PackedSize = []() {
+        if constexpr(is_same_v<remove_cvref_t<SrcData>, pk_i4_t>)
+            return 2;
+        else
+            return 1;
+    }();
+
     __device__ constexpr ThreadwiseTensorSliceTransfer_v4(const Index& src_ref_idx)
         : src_ref_coord_(make_tensor_coordinate(SrcDesc{}, src_ref_idx))
     {
@@ -1015,6 +1022,8 @@ struct ThreadwiseTensorSliceTransfer_v4
         static_assert(SliceLengths::At(Number<SrcVectorDim>{}) % SrcScalarPerVector == 0,
                       "wrong! Not divisible");
+
+        static_assert(!(is_same_v<remove_cvref_t<SrcData>, pk_i4_t> && (SrcScalarPerVector == 1)), "pk data N cannot be 1");
     }
 
     template <typename SrcRefToOriginDisplacement,
@@ -1109,7 +1118,7 @@ struct ThreadwiseTensorSliceTransfer_v4
         move_tensor_coordinate(src_desc, src_data_coord, src_ref_to_data_disp_coord_step);
 
-        vector_type_maker_t<SrcData, SrcScalarPerVector> src_tmp_vector;
+        vector_type_maker_t<SrcData, SrcScalarPerVector / PackedSize> src_tmp_vector;
 
         using src_vector_t = typename decltype(src_tmp_vector)::type;
@@ -1120,7 +1129,7 @@ struct ThreadwiseTensorSliceTransfer_v4
         if constexpr(SrcBuffer::IsDynamicBuffer())
         {
             src_tmp_vector.template AsType<src_vector_t>()(Number<0>{}) =
-                src_buf.template Get<src_vector_t>(src_data_coord.GetOffset(), is_src_valid);
+                src_buf.template Get<src_vector_t>(src_data_coord.GetOffset() / PackedSize, is_src_valid);
         }
         else if constexpr(SrcBuffer::IsStaticBuffer())
         {
@@ -1129,11 +1138,34 @@ struct ThreadwiseTensorSliceTransfer_v4
                     src_ref_to_origin_disp_idx + data_to_origin_disp_idx +
                     i * src_scalar_step_in_vector);
-                src_tmp_vector.template AsType<SrcData>()(i) = src_buf[Number<src_offset>{}];
+                src_tmp_vector.template AsType<SrcData>()(i) = src_buf[Number<src_offset / PackedSize>{}];
             });
         }
 
-        if constexpr(is_same<remove_cvref_t<SrcData>, f8_t>::value &&
+        if constexpr(is_same<remove_cvref_t<SrcData>, pk_i4_t>::value &&
+                     is_same<remove_cvref_t<DstData>, half_t>::value)
+        {
+            // copy data from src_tmp_vector to dst_tmp_vector (data cast data from SrcData to
+            // DstData)
+            vector_type_maker_t<DstData, SrcScalarPerVector> dst_tmp_vector;
+
+            using dst_v_t = typename vector_type_maker_t<DstData, PackedSize>::type;
+            using src_v_t = typename vector_type_maker_t<SrcData, 1>::type;
+
+            static_for<0, SrcScalarPerVector / PackedSize, 1>{}([&](auto i) {
+                ck::tensor_operation::element_wise::PassThroughPack2{}(
+                    dst_tmp_vector.template AsType<dst_v_t>()(i),
+                    src_tmp_vector.template AsType<src_v_t>()[i]);
+            });
+
+            // copy data from dst_tmp_vector into dst_buf
+            static_for<0, SrcScalarPerVector, 1>{}([&](auto i) {
+                constexpr index_t dst_offset = dst_desc.CalculateOffset(
+                    dst_origin_idx + data_to_origin_disp_idx + i * src_scalar_step_in_vector);
+
+                dst_buf(Number<dst_offset>{}) = dst_tmp_vector.template AsType<DstData>()[i];
+            });
+        }
+        else if constexpr(is_same<remove_cvref_t<SrcData>, f8_t>::value &&
                      is_same<remove_cvref_t<DstData>, half_t>::value &&
                      SrcScalarPerVector % 2 == 0)
         {
......
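The changes above, and the matching ones in `ThreadwiseTensorSliceTransfer_v3r1` below, follow a single rule: when `SrcData` is `pk_i4_t`, two logical elements share one byte, so per-vector element counts are divided by `PackedSize` and element offsets become byte offsets via `/ PackedSize`. A standalone sketch of that index arithmetic, assuming two 4-bit values per byte (illustrative names only; the real transfer loads whole vectors and converts nibble pairs through `PassThroughPack2` rather than reading one element at a time):

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Illustrative only: map a logical element index to the byte that stores it
// (offset / PackedSize) and pick the right nibble, mimicking the bookkeeping
// added to the transfer classes. Two int4 elements per pk_i4_t byte.
constexpr int kPackedSize = 2;

std::uint8_t load_packed_element(const std::vector<std::uint8_t>& buf, int element_idx)
{
    const int byte_offset   = element_idx / kPackedSize; // GetOffset() / PackedSize
    const std::uint8_t byte = buf.at(byte_offset);
    return (element_idx % kPackedSize == 0) ? (byte & 0x0f) // even index: low nibble
                                            : (byte >> 4);  // odd index: high nibble
}

int main()
{
    // Bytes {0x21, 0x43} store the logical int4 sequence 1, 2, 3, 4.
    const std::vector<std::uint8_t> buf{0x21, 0x43};
    assert(load_packed_element(buf, 0) == 1);
    assert(load_packed_element(buf, 1) == 2);
    assert(load_packed_element(buf, 2) == 3);
    assert(load_packed_element(buf, 3) == 4);
    return 0;
}
```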
@@ -31,8 +31,8 @@ template <typename SliceLengths,
           typename DstDimAccessOrder,
           index_t SrcVectorDim,
           index_t DstVectorDim,
-          index_t SrcScalarPerVector,
-          index_t DstScalarPerVector,
+          index_t SrcScalarPerVector_,
+          index_t DstScalarPerVector_,
           index_t SrcScalarStrideInVector,
           index_t DstScalarStrideInVector,
           bool SrcResetCoordinateAfterRun, // control whether to move back src coordinate after each
@@ -55,6 +55,17 @@ struct ThreadwiseTensorSliceTransfer_v3r1
     static constexpr auto I0 = Number<0>{};
 
+    static constexpr index_t PackedSize = []() {
+        if constexpr(is_same_v<remove_cvref_t<SrcData>, pk_i4_t>)
+            return 2;
+        else
+            return 1;
+    }();
+
+    static constexpr auto SrcScalarPerVector = Number<SrcScalarPerVector_ / PackedSize>{};
+    static constexpr auto DstScalarPerVector = Number<DstScalarPerVector_ / PackedSize>{};
+
     __device__ constexpr ThreadwiseTensorSliceTransfer_v3r1(
         const SrcDesc& src_desc,
         const Index& src_slice_origin,
@@ -67,6 +78,8 @@ struct ThreadwiseTensorSliceTransfer_v3r1
           src_element_op_(src_element_op),
           dst_element_op_(dst_element_op)
     {
+        static_assert(is_same_v<remove_cvref_t<SrcData>, remove_cvref_t<DstData>>, "SrcData != DstData");
+        static_assert(!(is_same_v<remove_cvref_t<SrcData>, pk_i4_t> && (SrcScalarPerVector == 1 || DstScalarPerVector == 1)), "pk data N cannot be 1");
     }
 
     __device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx)
@@ -95,11 +108,11 @@ struct ThreadwiseTensorSliceTransfer_v3r1
         // scalar per access on each dim
        // TODO: don't use lambda_scalar_per_access
         constexpr auto src_scalar_per_access = generate_sequence(
-            detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector>{}, Number<nDim>{});
+            detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector * PackedSize>{}, Number<nDim>{});
 
         constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access;
 
-        static_assert(SliceLengths::At(SrcVectorDim) % SrcScalarPerVector == 0,
+        static_assert(SliceLengths::At(SrcVectorDim) % (SrcScalarPerVector * PackedSize) == 0,
                       "SliceLengths[SrcVectorDim] must be divisible by SrcScalarPerVector");
 
         constexpr auto src_dim_access_order = SrcDimAccessOrder{};
@@ -181,7 +194,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1
             using src_vector_t = typename src_vector_type::type;
 
             auto src_vector_container =
-                src_vector_type{src_buf.template Get<src_vector_t>(src_coord_.GetOffset(), true)};
+                src_vector_type{src_buf.template Get<src_vector_t>(src_coord_.GetOffset() / PackedSize, true)};
 
             using dst_vector_type = vector_type_maker_t<DstData, SrcScalarPerVector>;
             using dst_vector_t = typename dst_vector_type::type;
@@ -279,7 +292,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1
         // OOB Check
         constexpr auto src_scalar_per_access = generate_sequence(
-            detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector>{}, Number<nDim>{});
+            detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector * PackedSize>{}, Number<nDim>{});
 
         constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access;
@@ -368,9 +381,9 @@ struct ThreadwiseTensorSliceTransfer_v3r1
         constexpr auto scalar_per_access = generate_sequence(
             detail::lambda_scalar_per_access_for_src_and_dst<SrcVectorDim,
-                                                             SrcScalarPerVector,
+                                                             SrcScalarPerVector * PackedSize,
                                                              DstVectorDim,
-                                                             DstScalarPerVector>{},
+                                                             DstScalarPerVector * PackedSize>{},
             Number<nDim>{});
 
         constexpr auto access_lengths = SliceLengths{} / scalar_per_access;
@@ -410,7 +423,12 @@ struct ThreadwiseTensorSliceTransfer_v3r1
         }
         else
         {
-            static_ford<SliceLengths>{}([&](auto idx) {
+            constexpr auto packed_per_access = generate_sequence(
+                detail::lambda_scalar_per_access<SrcVectorDim, PackedSize>{}, Number<nDim>{});
+
+            constexpr auto packed_access_lengths = SliceLengths{} / packed_per_access;
+
+            static_ford<decltype(packed_access_lengths)>{}([&](auto idx) {
                 dst_thread_scratch_(idx) = src_thread_scratch_tuple_[thread_scratch_id][idx];
             });
         }
@@ -438,7 +456,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1
         // src scalar per access on each dim
         // TODO: don't use this
         constexpr auto dst_scalar_per_access = generate_sequence(
-            detail::lambda_scalar_per_access<DstVectorDim, DstScalarPerVector>{}, Number<nDim>{});
+            detail::lambda_scalar_per_access<DstVectorDim, DstScalarPerVector * PackedSize>{}, Number<nDim>{});
 
         constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access;
@@ -532,7 +550,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1
         // copy data from dst_vector_container to dst_buf
         dst_buf.template Set<dst_vector_t>(
-            dst_coord_.GetOffset(),
+            dst_coord_.GetOffset() / PackedSize,
             is_dst_valid,
             dst_vector_container.template AsType<dst_vector_t>()[I0]);
......
@@ -429,7 +429,7 @@ __device__ typename vector_type<T, N>::type amd_buffer_load_impl(int32x4_t src_w
         (is_same<T, f8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
         (is_same<T, bf8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
         (is_same<T, int8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
-        (is_same<T, uint8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)),
+        (is_same<T, pk_i4_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)),
         "wrong! not implemented");
 
     using r_t = typename vector_type<T, N>::type;
......
@@ -12,6 +12,7 @@ using half_t = _Float16;
 using int4_t = _BitInt(4);
 using f8_t = _BitInt(8);
 using bf8_t = unsigned _BitInt(8);
+using pk_i4_t = unsigned char;
 
 // vector_type
 template <typename T, index_t N>
......
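With `pk_i4_t` defined as a plain `unsigned char`, one byte carries two 4-bit values, so a packed operand with K int4 elements per row occupies K/2 bytes. A hedged sketch of how such a buffer might be packed on the host (standalone; `pack_two_int4` and `pack_int4_buffer` are illustrative helpers, not CK APIs):

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Illustrative helpers (not CK APIs): pack two unsigned 4-bit values into one
// byte, element 0 in the low nibble and element 1 in the high nibble, matching
// the layout the device-side conversion unpacks.
std::uint8_t pack_two_int4(std::uint8_t lo, std::uint8_t hi)
{
    return static_cast<std::uint8_t>((lo & 0x0f) | ((hi & 0x0f) << 4));
}

// Pack a flat list of 4-bit values (each 0..15) into half as many bytes.
std::vector<std::uint8_t> pack_int4_buffer(const std::vector<std::uint8_t>& vals)
{
    assert(vals.size() % 2 == 0);
    std::vector<std::uint8_t> packed(vals.size() / 2);
    for(std::size_t i = 0; i < packed.size(); ++i)
        packed[i] = pack_two_int4(vals[2 * i], vals[2 * i + 1]);
    return packed;
}

int main()
{
    // {1, 2, 3, 4} packs to {0x21, 0x43}: the even-indexed element sits in the low nibble.
    const auto packed = pack_int4_buffer({1, 2, 3, 4});
    assert(packed == (std::vector<std::uint8_t>{0x21, 0x43}));
    return 0;
}
```

Putting the even-indexed element in the low nibble keeps the layout consistent with the unpacking performed by `PassThroughPack2` above.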