"vscode:/vscode.git/clone" did not exist on "ba26d8ceade2e235d13da06829c8eab8125e2527"
Commit 809a0c97 authored by mtgu0705's avatar mtgu0705
Browse files

fp8xint4 bpreshuffle function pass

parent 8c0e03ba
...@@ -58,7 +58,7 @@ using CElementOp = PassThrough; ...@@ -58,7 +58,7 @@ using CElementOp = PassThrough;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr bool PermuteA = false; static constexpr bool PermuteA = false;
static constexpr bool PermuteB = true; static constexpr bool PermuteB = false;
static constexpr ck::index_t KPerBlock = 128; static constexpr ck::index_t KPerBlock = 128;
// clang-format off // clang-format off
...@@ -131,7 +131,6 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) ...@@ -131,7 +131,6 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
Tensor<BDataType> b_k_n_preshuffled(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); Tensor<BDataType> b_k_n_preshuffled(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
Tensor<BDataType> b_k_n_permute(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
switch(config.init_method) switch(config.init_method)
{ {
...@@ -161,51 +160,42 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) ...@@ -161,51 +160,42 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
std::cout << "b_k_n_preshuffled:" << b_k_n_preshuffled.mDesc << std::endl;
std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl; std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl;
// std::cout << "a_m_K size: " << sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize()
// << std::endl;
// std::cout << "BDataType size: " << sizeof(BDataType) << std::endl;
// std::cout << "b_k_n size: " << sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize()
// << std::endl;
// std::cout << "c_m_n size: " << sizeof(CDataType) * c_m_n_host_result.mDesc.GetElementSpaceSize()
// << std::endl;
DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize()); DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize());
DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n_permute.mDesc.GetElementSpaceSize()); DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n_preshuffled.mDesc.GetElementSpaceSize());
DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize()); DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize());
// do GEMM // do GEMM
auto gemm = DeviceGemmV2Instance{}; auto gemm = DeviceGemmV2Instance{};
int NperXdl = gemm.GetPreShuffleParameters(); // weight pre-shuffle
preShuffleBuffer(b_k_n.mData.data(), b_k_n_preshuffled.mData.data(), N, K, NperXdl); int KPack = 32; // int4 -> 32, fp8 -> 16, fp16 -> 8
int NLane = gemm.GetPreShuffleParameters();
// weight permute int KLane = 64 / NLane;
if constexpr(PermuteB)
{
int K1 = KPerBlock;
int K0 = K / KPerBlock;
// int K0, N, K1 int K0 = K / (KLane * KPack);
for(int j = 0; j < K0; j++) // K -> K0 KLane KPack
{ // N -> N0 NLane
for(int i = 0; i < N; i++) // N, K -> N0 K0 KLane NLane KPack
{ int tempk;
for(int jj = 0; jj < K1; jj++) for(int n=0;n<N;++n)
{
b_k_n_permute(j * N * K1 + i * K1 + jj) = b_k_n_preshuffled(i * K + (j * K1 + jj));
}
}
}
}
else
{
for(int i = 0; i < N; i++)
{ {
for(int j = 0; j < K; j++) for(int k=0;k<K;++k)
{ {
b_k_n_permute(i * K + j) = b_k_n_preshuffled(i * K + j); int n0 = n / NLane;
} int n1 = n % NLane;
int k0 = k / (KLane * KPack);
tempk = k % (KLane * KPack);
int k1 = tempk / KPack;
int k2 = tempk % KPack;
int outputIndex = n0 * KPack * NLane * KLane * K0 + k0 * KPack * NLane * KLane +
k1 * KPack * NLane + n1 * KPack + k2;
b_k_n_preshuffled(outputIndex) = b_k_n(n * K + k);
} }
} }
...@@ -218,7 +208,7 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) ...@@ -218,7 +208,7 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
for(int k = 0; k < 4; k++) for(int k = 0; k < 4; k++)
{ {
int i4x2 = b_k_n_permute(j + k * 2, i).data; int i4x2 = b_k_n_preshuffled(j + k * 2, i).data;
input[k * 2 + 0] = (i4x2 >> 4) & 0xf; input[k * 2 + 0] = (i4x2 >> 4) & 0xf;
input[k * 2 + 1] = (i4x2 >> 0) & 0xf; input[k * 2 + 1] = (i4x2 >> 0) & 0xf;
} }
...@@ -229,7 +219,7 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) ...@@ -229,7 +219,7 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
int lo = input[0]; int lo = input[0];
int i4x2 = (hi << 4) | lo; int i4x2 = (hi << 4) | lo;
b_k_n_permute(j + 0, i) = i4x2; b_k_n_preshuffled(j + 0, i) = i4x2;
} }
{ {
...@@ -237,7 +227,7 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) ...@@ -237,7 +227,7 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
int lo = input[4]; int lo = input[4];
int i4x2 = (hi << 4) | lo; int i4x2 = (hi << 4) | lo;
b_k_n_permute(j + 2, i) = i4x2; b_k_n_preshuffled(j + 2, i) = i4x2;
} }
{ {
...@@ -245,7 +235,7 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) ...@@ -245,7 +235,7 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
int lo = input[1]; int lo = input[1];
int i4x2 = (hi << 4) | lo; int i4x2 = (hi << 4) | lo;
b_k_n_permute(j + 4, i) = i4x2; b_k_n_preshuffled(j + 4, i) = i4x2;
} }
{ {
...@@ -253,13 +243,13 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) ...@@ -253,13 +243,13 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
int lo = input[5]; int lo = input[5];
int i4x2 = (hi << 4) | lo; int i4x2 = (hi << 4) | lo;
b_k_n_permute(j + 6, i) = i4x2; b_k_n_preshuffled(j + 6, i) = i4x2;
} }
} }
} }
a_m_k_device_buf.ToDevice(a_m_k.mData.data()); a_m_k_device_buf.ToDevice(a_m_k.mData.data());
b_k_n_device_buf.ToDevice(b_k_n_permute.mData.data()); b_k_n_device_buf.ToDevice(b_k_n_preshuffled.mData.data());
DeviceMem workspace; DeviceMem workspace;
auto a_element_op = AElementOp{}; auto a_element_op = AElementOp{};
......
...@@ -1205,7 +1205,8 @@ struct GridwiseGemm_xdl_cshuffle_v3_b_preshuffle ...@@ -1205,7 +1205,8 @@ struct GridwiseGemm_xdl_cshuffle_v3_b_preshuffle
auto b_blockwise_copy = ThreadwiseTensorSliceTransfer_v2< auto b_blockwise_copy = ThreadwiseTensorSliceTransfer_v2<
BDataType, BDataType,
BDataType, // BDataType,
ADataType,
decltype(b_grid_desc_bpreshuffled), decltype(b_grid_desc_bpreshuffled),
decltype(b_block_desc_bk0_n_bk1), decltype(b_block_desc_bk0_n_bk1),
Sequence<Number<NXdlPerWave>{}, I1, Number<KRepeat>{}, Number<BK1Value>{}>, Sequence<Number<NXdlPerWave>{}, I1, Number<KRepeat>{}, Number<BK1Value>{}>,
...@@ -1221,7 +1222,6 @@ struct GridwiseGemm_xdl_cshuffle_v3_b_preshuffle ...@@ -1221,7 +1222,6 @@ struct GridwiseGemm_xdl_cshuffle_v3_b_preshuffle
// LDS allocation for A and B: be careful of alignment // LDS allocation for A and B: be careful of alignment
// Cast after lds // Cast after lds
auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>( auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<ADataType*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); static_cast<ADataType*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
......
...@@ -224,6 +224,13 @@ struct ThreadwiseTensorSliceTransfer_v2 ...@@ -224,6 +224,13 @@ struct ThreadwiseTensorSliceTransfer_v2
using SrcCoordStep = decltype(make_tensor_coordinate_step(SrcDesc{}, Index{})); using SrcCoordStep = decltype(make_tensor_coordinate_step(SrcDesc{}, Index{}));
// Divisor applied to SrcData element counts and source-buffer offsets in this transfer:
// 2 when SrcData (with cv/ref qualifiers stripped) is pk_i4_t, otherwise 1.
// NOTE(review): presumably pk_i4_t stores two 4-bit values per element -- confirm.
static constexpr index_t PackedSize = is_same_v<remove_cvref_t<SrcData>, pk_i4_t> ? 2 : 1;
__device__ constexpr ThreadwiseTensorSliceTransfer_v2(const SrcDesc& src_desc, __device__ constexpr ThreadwiseTensorSliceTransfer_v2(const SrcDesc& src_desc,
const Index& src_slice_origin_idx) const Index& src_slice_origin_idx)
: src_coord_(make_tensor_coordinate(src_desc, src_slice_origin_idx)) : src_coord_(make_tensor_coordinate(src_desc, src_slice_origin_idx))
...@@ -232,6 +239,11 @@ struct ThreadwiseTensorSliceTransfer_v2 ...@@ -232,6 +239,11 @@ struct ThreadwiseTensorSliceTransfer_v2
"wrong! SrcDesc need to known at compile-time"); "wrong! SrcDesc need to known at compile-time");
static_assert(SliceLengths::At(Number<SrcVectorDim>{}) % SrcScalarPerVector == 0, static_assert(SliceLengths::At(Number<SrcVectorDim>{}) % SrcScalarPerVector == 0,
"wrong! Not divisible"); "wrong! Not divisible");
if constexpr(is_same_v<remove_cvref_t<SrcData>, pk_i4_t>)
{
static_assert(SrcScalarPerVector % PackedSize == 0, "pk data N cannot be 1");
}
} }
__device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx) __device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx)
...@@ -275,6 +287,72 @@ struct ThreadwiseTensorSliceTransfer_v2 ...@@ -275,6 +287,72 @@ struct ThreadwiseTensorSliceTransfer_v2
// loop over tensor and copy // loop over tensor and copy
constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess(); constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess();
if constexpr(is_same<remove_cvref_t<SrcData>, pk_i4_t>::value)
{
static_for<0, num_access, 1>{}([&](auto idx_1d) {
typename vector_type_maker<SrcData, SrcScalarPerVector / PackedSize>::type src_tmp_vector;
using src_vector_t = typename decltype(src_tmp_vector)::type;
constexpr auto src_data_idx = SpaceFillingCurve::GetIndex(idx_1d);
const bool is_src_valid =
coordinate_has_valid_offset_assuming_visible_index_is_valid(src_desc, src_coord_);
// copy data from src_buf into src_tmp_vector
src_tmp_vector.template AsType<src_vector_t>()(Number<0>{}) =
src_buf.template Get<src_vector_t>(src_coord_.GetOffset() / PackedSize,
is_src_valid);
// copy data from src_tmp_vector to dst_tmp_vector (casting each element from SrcData to
// DstData)
vector_type_maker_t<DstData, SrcScalarPerVector> dst_tmp_vector;
constexpr index_t pack_size = 8;
static_assert(SrcScalarPerVector % pack_size == 0, "");
using src_v_t = typename vector_type_maker_t<SrcData, pack_size / PackedSize>::type;
using dst_v_t = typename vector_type_maker_t<DstData, pack_size>::type;
static_for<0, SrcScalarPerVector / pack_size, 1>{}([&](auto i) {
ck::tensor_operation::element_wise::PassThroughPack8{}(
dst_tmp_vector.template AsType<dst_v_t>()(i),
src_tmp_vector.template AsType<src_v_t>()[i]);
});
// copy data from dst_tmp_vector into dst_buf
static_for<0, SrcScalarPerVector, 1>{}([&](auto i) {
constexpr index_t dst_offset =
dst_desc.CalculateOffset(to_multi_index(dst_slice_origin_idx) + src_data_idx +
i * src_scalar_step_in_vector);
if constexpr(InvalidElementAsNaN)
{
dst_buf(Number<dst_offset>{}) =
is_src_valid
? dst_tmp_vector.template AsType<DstData>()[i]
: NumericLimits<DstData>::QuietNaN();
}
else
{
dst_buf(Number<dst_offset>{}) =
dst_tmp_vector.template AsType<DstData>()[i];
// type_convert<DstData>(src_vector.template AsType<SrcData>()[i]);
}
});
if constexpr(idx_1d.value != num_access - 1)
{
constexpr auto forward_step = SpaceFillingCurve::GetForwardStep(idx_1d);
move_tensor_coordinate(
src_desc, src_coord_, make_tensor_coordinate_step(src_desc, forward_step));
}
});
}
else
{
static_for<0, num_access, 1>{}([&](auto idx_1d) { static_for<0, num_access, 1>{}([&](auto idx_1d) {
typename vector_type_maker<SrcData, SrcScalarPerVector>::type src_vector; typename vector_type_maker<SrcData, SrcScalarPerVector>::type src_vector;
...@@ -317,6 +395,7 @@ struct ThreadwiseTensorSliceTransfer_v2 ...@@ -317,6 +395,7 @@ struct ThreadwiseTensorSliceTransfer_v2
src_desc, src_coord_, make_tensor_coordinate_step(src_desc, forward_step)); src_desc, src_coord_, make_tensor_coordinate_step(src_desc, forward_step));
} }
}); });
}
// move src coordinate back to slice origin (or not) // move src coordinate back to slice origin (or not)
if constexpr(SrcResetCoordinateAfterRun) if constexpr(SrcResetCoordinateAfterRun)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment