Commit 9a3f75ee authored by mtgu0705

fp8xint4 bpreshuffle function pass

parent ba5a6a24
@@ -58,7 +58,7 @@ using CElementOp = PassThrough;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr bool PermuteA = false;
static constexpr bool PermuteB = true;
static constexpr bool PermuteB = false;
static constexpr ck::index_t KPerBlock = 128;
// clang-format off
@@ -131,7 +131,6 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
Tensor<BDataType> b_k_n_preshuffled(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
Tensor<BDataType> b_k_n_permute(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
switch(config.init_method)
{
@@ -161,51 +160,42 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
std::cout << "b_k_n_preshuffled:" << b_k_n_preshuffled.mDesc << std::endl;
std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl;
// std::cout << "a_m_K size: " << sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize()
// << std::endl;
// std::cout << "BDataType size: " << sizeof(BDataType) << std::endl;
// std::cout << "b_k_n size: " << sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize()
// << std::endl;
// std::cout << "c_m_n size: " << sizeof(CDataType) * c_m_n_host_result.mDesc.GetElementSpaceSize()
// << std::endl;
DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize());
DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n_permute.mDesc.GetElementSpaceSize());
DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n_preshuffled.mDesc.GetElementSpaceSize());
DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize());
// do GEMM
auto gemm = DeviceGemmV2Instance{};
auto gemm = DeviceGemmV2Instance{};
int NperXdl = gemm.GetPreShuffleParameters();
preShuffleBuffer(b_k_n.mData.data(), b_k_n_preshuffled.mData.data(), N, K, NperXdl);
// weight permute
if constexpr(PermuteB)
{
int K1 = KPerBlock;
int K0 = K / KPerBlock;
// weight pre-shuffle
int KPack = 32; // int4 -> 32, fp8 -> 16, fp16 -> 8
int NLane = gemm.GetPreShuffleParameters();
int KLane = 64 / NLane;
// int K0, N, K1
for(int j = 0; j < K0; j++)
{
for(int i = 0; i < N; i++)
{
for(int jj = 0; jj < K1; jj++)
{
b_k_n_permute(j * N * K1 + i * K1 + jj) = b_k_n_preshuffled(i * K + (j * K1 + jj));
}
}
}
}
else
int K0 = K / (KLane * KPack);
// K -> K0 KLane KPack
// N -> N0 NLane
// N, K -> N0 K0 KLane NLane KPack
int tempk;
for(int n=0;n<N;++n)
{
for(int i = 0; i < N; i++)
for(int k=0;k<K;++k)
{
for(int j = 0; j < K; j++)
{
b_k_n_permute(i * K + j) = b_k_n_preshuffled(i * K + j);
}
int n0 = n / NLane;
int n1 = n % NLane;
int k0 = k / (KLane * KPack);
tempk = k % (KLane * KPack);
int k1 = tempk / KPack;
int k2 = tempk % KPack;
int outputIndex = n0 * KPack * NLane * KLane * K0 + k0 * KPack * NLane * KLane +
k1 * KPack * NLane + n1 * KPack + k2;
b_k_n_preshuffled(outputIndex) = b_k_n(n * K + k);
}
}
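For reference, a minimal standalone sketch of the N, K -> N0 K0 KLane NLane KPack index mapping used by the preshuffle loop above. The formula is the same as outputIndex in the diff; the parameter values (NLane = 16, KLane = 64 / NLane = 4, KPack = 32 for packed int4) are illustrative assumptions, since the real values come from gemm.GetPreShuffleParameters() and the element type.

#include <cstdio>

// Standalone analogue of the preshuffle loop above: maps a row-major (N, K)
// element index to its position in the N0 x K0 x KLane x NLane x KPack buffer.
int preshuffled_index(int n, int k, int K, int NLane, int KLane, int KPack)
{
    const int K0    = K / (KLane * KPack);
    const int n0    = n / NLane;
    const int n1    = n % NLane;
    const int k0    = k / (KLane * KPack);
    const int tempk = k % (KLane * KPack);
    const int k1    = tempk / KPack;
    const int k2    = tempk % KPack;
    return n0 * KPack * NLane * KLane * K0 + k0 * KPack * NLane * KLane +
           k1 * KPack * NLane + n1 * KPack + k2;
}

int main()
{
    // Assumed example sizes: NLane = 16, KLane = 64 / 16 = 4, KPack = 32.
    const int K = 256, NLane = 16, KLane = 4, KPack = 32;
    // Element (n=17, k=70) splits into n0=1, n1=1, k0=0, k1=2, k2=6 -> offset 5158.
    std::printf("%d\n", preshuffled_index(17, 70, K, NLane, KLane, KPack));
    return 0;
}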
@@ -218,7 +208,7 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
for(int k = 0; k < 4; k++)
{
int i4x2 = b_k_n_permute(j + k * 2, i).data;
int i4x2 = b_k_n_preshuffled(j + k * 2, i).data;
input[k * 2 + 0] = (i4x2 >> 4) & 0xf;
input[k * 2 + 1] = (i4x2 >> 0) & 0xf;
}
@@ -229,7 +219,7 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
int lo = input[0];
int i4x2 = (hi << 4) | lo;
b_k_n_permute(j + 0, i) = i4x2;
b_k_n_preshuffled(j + 0, i) = i4x2;
}
{
@@ -237,7 +227,7 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
int lo = input[4];
int i4x2 = (hi << 4) | lo;
b_k_n_permute(j + 2, i) = i4x2;
b_k_n_preshuffled(j + 2, i) = i4x2;
}
{
@@ -245,7 +235,7 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
int lo = input[1];
int i4x2 = (hi << 4) | lo;
b_k_n_permute(j + 4, i) = i4x2;
b_k_n_preshuffled(j + 4, i) = i4x2;
}
{
@@ -253,13 +243,13 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
int lo = input[5];
int i4x2 = (hi << 4) | lo;
b_k_n_permute(j + 6, i) = i4x2;
b_k_n_preshuffled(j + 6, i) = i4x2;
}
}
}
a_m_k_device_buf.ToDevice(a_m_k.mData.data());
b_k_n_device_buf.ToDevice(b_k_n_permute.mData.data());
b_k_n_device_buf.ToDevice(b_k_n_preshuffled.mData.data());
DeviceMem workspace;
auto a_element_op = AElementOp{};
......
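The repack loop in the hunks above operates on packed int4 pairs: each byte of BDataType (pk_i4_t) holds two 4-bit values, one in the high nibble and one in the low nibble, and the loop re-pairs the eight values read along K so each stored byte matches the order the kernel unpacks them. A small self-contained sketch of the pack/unpack arithmetic those hi/lo shifts implement (the specific pairing order is particular to this example and partly elided in the hunks above):

#include <cstdint>
#include <cstdio>

// Unpack one byte holding two int4 values, high nibble first, mirroring
// "input[k*2+0] = (i4x2 >> 4) & 0xf; input[k*2+1] = (i4x2 >> 0) & 0xf;".
void unpack_i4x2(std::uint8_t i4x2, int& hi, int& lo)
{
    hi = (i4x2 >> 4) & 0xf;
    lo = (i4x2 >> 0) & 0xf;
}

// Pack two int4 values back into one byte, mirroring "int i4x2 = (hi << 4) | lo;".
std::uint8_t pack_i4x2(int hi, int lo)
{
    return static_cast<std::uint8_t>(((hi & 0xf) << 4) | (lo & 0xf));
}

int main()
{
    int hi = 0, lo = 0;
    unpack_i4x2(0xA3, hi, lo);                               // hi = 10, lo = 3
    std::printf("%d %d %02x\n", hi, lo, pack_i4x2(hi, lo));  // prints "10 3 a3"
    return 0;
}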
@@ -1205,7 +1205,8 @@ struct GridwiseGemm_xdl_cshuffle_v3_b_preshuffle
auto b_blockwise_copy = ThreadwiseTensorSliceTransfer_v2<
BDataType,
BDataType,
// BDataType,
ADataType,
decltype(b_grid_desc_bpreshuffled),
decltype(b_block_desc_bk0_n_bk1),
Sequence<Number<NXdlPerWave>{}, I1, Number<KRepeat>{}, Number<BK1Value>{}>,
@@ -1221,7 +1222,6 @@ struct GridwiseGemm_xdl_cshuffle_v3_b_preshuffle
// LDS allocation for A and B: be careful of alignment
// Cast after lds
auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<ADataType*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
......
@@ -224,6 +224,13 @@ struct ThreadwiseTensorSliceTransfer_v2
using SrcCoordStep = decltype(make_tensor_coordinate_step(SrcDesc{}, Index{}));
static constexpr index_t PackedSize = []() {
if constexpr(is_same_v<remove_cvref_t<SrcData>, pk_i4_t>)
return 2;
else
return 1;
}();
__device__ constexpr ThreadwiseTensorSliceTransfer_v2(const SrcDesc& src_desc,
const Index& src_slice_origin_idx)
: src_coord_(make_tensor_coordinate(src_desc, src_slice_origin_idx))
@@ -232,6 +239,11 @@ struct ThreadwiseTensorSliceTransfer_v2
"wrong! SrcDesc need to known at compile-time");
static_assert(SliceLengths::At(Number<SrcVectorDim>{}) % SrcScalarPerVector == 0,
"wrong! Not divisible");
if constexpr(is_same_v<remove_cvref_t<SrcData>, pk_i4_t>)
{
static_assert(SrcScalarPerVector % PackedSize == 0, "pk data N cannot be 1");
}
}
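For illustration, a compilable sketch of the PackedSize idea added above: packed-int4 data stores two logical elements per byte, so per-access vector lengths and byte offsets have to be expressed in units of the packed storage, and an access of a single logical element is impossible (hence the "pk data N cannot be 1" assert). The pk_i4_t below is a local stand-in, not the CK type.

#include <cstdint>
#include <type_traits>

// Local stand-in for ck::pk_i4_t: one byte carrying two 4-bit elements.
struct pk_i4_t { std::uint8_t data; };

template <typename SrcData, int SrcScalarPerVector>
struct transfer_sketch
{
    // Same immediately-invoked constexpr lambda pattern as in the diff:
    // two logical elements share one byte of packed-int4 storage.
    static constexpr int PackedSize = [] {
        if constexpr(std::is_same_v<std::remove_cv_t<SrcData>, pk_i4_t>)
            return 2;
        else
            return 1;
    }();

    // A packed type cannot be accessed one logical element at a time.
    static_assert(SrcScalarPerVector % PackedSize == 0, "pk data N cannot be 1");

    // Byte (storage) offset of logical element i within the source buffer.
    static constexpr int byte_offset(int i) { return i / PackedSize; }
};

// 8 logical int4 elements occupy 4 bytes; 8 float elements occupy 8 slots.
static_assert(transfer_sketch<pk_i4_t, 8>::byte_offset(7) == 3, "");
static_assert(transfer_sketch<float, 8>::byte_offset(7) == 7, "");

int main() { return 0; }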
__device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx)
@@ -275,48 +287,115 @@ struct ThreadwiseTensorSliceTransfer_v2
// loop over tensor and copy
constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess();
static_for<0, num_access, 1>{}([&](auto idx_1d) {
typename vector_type_maker<SrcData, SrcScalarPerVector>::type src_vector;
if constexpr(is_same<remove_cvref_t<SrcData>, pk_i4_t>::value)
{
static_for<0, num_access, 1>{}([&](auto idx_1d) {
typename vector_type_maker<SrcData, SrcScalarPerVector / PackedSize>::type src_tmp_vector;
using src_vector_t =
typename vector_type_maker<SrcData, SrcScalarPerVector>::type::type;
constexpr auto src_data_idx = SpaceFillingCurve::GetIndex(idx_1d);
using src_vector_t = typename decltype(src_tmp_vector)::type;
const bool is_src_valid =
coordinate_has_valid_offset_assuming_visible_index_is_valid(src_desc, src_coord_);
constexpr auto src_data_idx = SpaceFillingCurve::GetIndex(idx_1d);
// copy data from src_buf into src_vector
src_vector.template AsType<src_vector_t>()(Number<0>{}) =
src_buf.template Get<src_vector_t>(src_coord_.GetOffset(), is_src_valid);
const bool is_src_valid =
coordinate_has_valid_offset_assuming_visible_index_is_valid(src_desc, src_coord_);
// copy data from src_vector into dst_buf
static_for<0, SrcScalarPerVector, 1>{}([&](auto i) {
constexpr index_t dst_offset =
dst_desc.CalculateOffset(to_multi_index(dst_slice_origin_idx) + src_data_idx +
i * src_scalar_step_in_vector);
// copy data from src_buf into src_tmp_vector
src_tmp_vector.template AsType<src_vector_t>()(Number<0>{}) =
src_buf.template Get<src_vector_t>(src_coord_.GetOffset() / PackedSize,
is_src_valid);
if constexpr(InvalidElementAsNaN)
{
dst_buf(Number<dst_offset>{}) =
is_src_valid
? type_convert<DstData>(src_vector.template AsType<SrcData>()[i])
: NumericLimits<DstData>::QuietNaN();
}
else
// copy data from src_tmp_vector to dst_tmp_vector (data cast data from SrcData to
// DstData)
vector_type_maker_t<DstData, SrcScalarPerVector> dst_tmp_vector;
constexpr index_t pack_size = 8;
static_assert(SrcScalarPerVector % pack_size == 0, "");
using src_v_t = typename vector_type_maker_t<SrcData, pack_size / PackedSize>::type;
using dst_v_t = typename vector_type_maker_t<DstData, pack_size>::type;
static_for<0, SrcScalarPerVector / pack_size, 1>{}([&](auto i) {
ck::tensor_operation::element_wise::PassThroughPack8{}(
dst_tmp_vector.template AsType<dst_v_t>()(i),
src_tmp_vector.template AsType<src_v_t>()[i]);
});
// copy data from dst_tmp_vector into dst_buf
static_for<0, SrcScalarPerVector, 1>{}([&](auto i) {
constexpr index_t dst_offset =
dst_desc.CalculateOffset(to_multi_index(dst_slice_origin_idx) + src_data_idx +
i * src_scalar_step_in_vector);
if constexpr(InvalidElementAsNaN)
{
dst_buf(Number<dst_offset>{}) =
is_src_valid
? dst_tmp_vector.template AsType<DstData>()[i]
: NumericLimits<DstData>::QuietNaN();
}
else
{
dst_buf(Number<dst_offset>{}) =
dst_tmp_vector.template AsType<DstData>()[i];
// type_convert<DstData>(src_vector.template AsType<SrcData>()[i]);
}
});
if constexpr(idx_1d.value != num_access - 1)
{
dst_buf(Number<dst_offset>{}) =
type_convert<DstData>(src_vector.template AsType<SrcData>()[i]);
constexpr auto forward_step = SpaceFillingCurve::GetForwardStep(idx_1d);
move_tensor_coordinate(
src_desc, src_coord_, make_tensor_coordinate_step(src_desc, forward_step));
}
});
}
else
{
static_for<0, num_access, 1>{}([&](auto idx_1d) {
typename vector_type_maker<SrcData, SrcScalarPerVector>::type src_vector;
if constexpr(idx_1d.value != num_access - 1)
{
constexpr auto forward_step = SpaceFillingCurve::GetForwardStep(idx_1d);
using src_vector_t =
typename vector_type_maker<SrcData, SrcScalarPerVector>::type::type;
constexpr auto src_data_idx = SpaceFillingCurve::GetIndex(idx_1d);
move_tensor_coordinate(
src_desc, src_coord_, make_tensor_coordinate_step(src_desc, forward_step));
}
});
const bool is_src_valid =
coordinate_has_valid_offset_assuming_visible_index_is_valid(src_desc, src_coord_);
// copy data from src_buf into src_vector
src_vector.template AsType<src_vector_t>()(Number<0>{}) =
src_buf.template Get<src_vector_t>(src_coord_.GetOffset(), is_src_valid);
// copy data from src_vector into dst_buf
static_for<0, SrcScalarPerVector, 1>{}([&](auto i) {
constexpr index_t dst_offset =
dst_desc.CalculateOffset(to_multi_index(dst_slice_origin_idx) + src_data_idx +
i * src_scalar_step_in_vector);
if constexpr(InvalidElementAsNaN)
{
dst_buf(Number<dst_offset>{}) =
is_src_valid
? type_convert<DstData>(src_vector.template AsType<SrcData>()[i])
: NumericLimits<DstData>::QuietNaN();
}
else
{
dst_buf(Number<dst_offset>{}) =
type_convert<DstData>(src_vector.template AsType<SrcData>()[i]);
}
});
if constexpr(idx_1d.value != num_access - 1)
{
constexpr auto forward_step = SpaceFillingCurve::GetForwardStep(idx_1d);
move_tensor_coordinate(
src_desc, src_coord_, make_tensor_coordinate_step(src_desc, forward_step));
}
});
}
// move src coordinate back to slice origin (or not)
if constexpr(SrcResetCoordinateAfterRun)
......
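A host-side sketch of what the new pk_i4_t branch in Run() above does conceptually: read SrcScalarPerVector / PackedSize packed bytes from the source, expand them eight logical elements at a time (the role PassThroughPack8 plays on device), then write SrcScalarPerVector converted elements to the destination. The int4-to-float decode below is purely illustrative; the real kernel converts to the fp8 data type via CK's element-wise op, and the exact nibble/lane ordering is set up by the host repack shown earlier.

#include <array>
#include <cstdint>
#include <cstdio>
#include <vector>

// Decode one unsigned 4-bit field; widened to float just to make the data flow visible.
float decode_i4(std::uint8_t nibble) { return static_cast<float>(nibble & 0xf); }

// Expand one "pack of 8": four packed bytes -> eight destination elements,
// the granularity the PassThroughPack8 call works at on device.
std::array<float, 8> expand_pack8(const std::uint8_t* src4bytes)
{
    std::array<float, 8> out{};
    for(int b = 0; b < 4; ++b)
    {
        out[b * 2 + 0] = decode_i4(src4bytes[b] >> 4); // high nibble first (assumed order)
        out[b * 2 + 1] = decode_i4(src4bytes[b] >> 0); // low nibble second
    }
    return out;
}

int main()
{
    constexpr int SrcScalarPerVector = 16; // logical int4 elements per access
    constexpr int PackedSize         = 2;  // two int4 values per byte
    std::vector<std::uint8_t> src(SrcScalarPerVector / PackedSize, 0x21);
    std::vector<float> dst(SrcScalarPerVector, 0.f);

    // Process the access in pack_size = 8 chunks, as the device loop does.
    for(int i = 0; i < SrcScalarPerVector / 8; ++i)
    {
        const auto pack = expand_pack8(src.data() + i * 8 / PackedSize);
        for(int j = 0; j < 8; ++j)
            dst[i * 8 + j] = pack[j];
    }
    std::printf("%g %g\n", dst[0], dst[1]); // prints "2 1" for the 0x21 fill
    return 0;
}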