"model/models/vscode:/vscode.git/clone" did not exist on "ec9eb28f4c3481d58c6da38ee488cb8cd5379256"
Commit 837d9056 authored by mtgu0705's avatar mtgu0705
Browse files

Added b preshuffle pipeline v3 support.

parent e53e764d
...@@ -28,9 +28,9 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecializa ...@@ -28,9 +28,9 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecializa
static constexpr bool PermuteA = false; static constexpr bool PermuteA = false;
static constexpr bool PermuteB = false; static constexpr bool PermuteB = false;
static constexpr ck::index_t KPerBlock = 128;
// clang-format off // clang-format off
#if 0
using DeviceGemmV2Instance = using DeviceGemmV2Instance =
ck::tensor_operation::device::DeviceGemm_Xdl_CShuffleV3_BPreshuffle< ck::tensor_operation::device::DeviceGemm_Xdl_CShuffleV3_BPreshuffle<
ALayout, BLayout, CLayout, ALayout, BLayout, CLayout,
...@@ -38,7 +38,7 @@ using DeviceGemmV2Instance = ...@@ -38,7 +38,7 @@ using DeviceGemmV2Instance =
AElementOp, BElementOp, CElementOp, GemmDefault, AElementOp, BElementOp, CElementOp, GemmDefault,
256, 256,
128, 128, 128, 128,
KPerBlock, 16, 32, 256, 16, 32,
32, 32, 32, 32,
4, 1, 4, 1,
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>,
...@@ -47,7 +47,26 @@ using DeviceGemmV2Instance = ...@@ -47,7 +47,26 @@ using DeviceGemmV2Instance =
2, 32, 32, 0, 2, 32, 32, 0,
1, 1, S<1, 32, 1, 8>, 4, 1, 1, S<1, 32, 1, 8>, 4,
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, F8, F8, PermuteA, PermuteB>; ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, F8, F8, PermuteA, PermuteB>;
#else
using DeviceGemmV2Instance =
ck::tensor_operation::device::DeviceGemm_Xdl_CShuffleV3_BPreshuffle<
ALayout, BLayout, CLayout,
ADataType, BDataType, CDataType, AccDataType, CShuffleDataType,
AElementOp, BElementOp, CElementOp, GemmDefault,
256,
256, 256,
128, 16, 32,
32, 32,
4, 4,
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>,
2, 16, 16, 0,
S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>,
2, 32, 32, 0,
1, 1, S<1, 32, 1, 8>, 8,
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, F8, F8, PermuteA, PermuteB>;
#endif
// clang-format on // clang-format on
template <typename ProblemType> template <typename ProblemType>
......
...@@ -510,10 +510,13 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3<BlockGemmPipelineScheduler::I ...@@ -510,10 +510,13 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3<BlockGemmPipelineScheduler::I
__builtin_amdgcn_sched_barrier(0); __builtin_amdgcn_sched_barrier(0);
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>( auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
a_thread_desc_.GetElementSpaceSize()); a_thread_desc_.GetElementSpaceSize());
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>( auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, BDataType>(
b_thread_desc_.GetElementSpaceSize());
auto b_thread_dequant_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
b_thread_desc_.GetElementSpaceSize()); b_thread_desc_.GetElementSpaceSize());
StaticallyIndexedArray<decltype(b_thread_buf), Number<2>{}> b_thread_bufs; StaticallyIndexedArray<decltype(b_thread_buf), Number<2>{}> b_thread_bufs;
StaticallyIndexedArray<decltype(b_thread_dequant_buf), Number<2>{}> b_thread_dequant_bufs;
constexpr auto b_block_origin_idx = make_tuple(I0, I0, I0, I0); constexpr auto b_block_origin_idx = make_tuple(I0, I0, I0, I0);
// Global prefetch A1 B1 // Global prefetch A1 B1
...@@ -545,6 +548,13 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3<BlockGemmPipelineScheduler::I ...@@ -545,6 +548,13 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3<BlockGemmPipelineScheduler::I
make_tuple(I0, I0, I0, k0, I0, I0), make_tuple(I0, I0, I0, k0, I0, I0),
a_thread_buf); a_thread_buf);
}); });
// B VGPR->VGPR dequant
b_thread_dequant_copy_.Run(b_block_desc_n0_n1_k0_k1,
b_block_origin_idx,
b_thread_bufs(I0),
b_thread_desc_,
make_tuple(I0, I0, I0, I0),
b_thread_dequant_bufs(I0));
// Initialize C // Initialize C
c_thread_buf.Clear(); c_thread_buf.Clear();
...@@ -594,9 +604,9 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3<BlockGemmPipelineScheduler::I ...@@ -594,9 +604,9 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3<BlockGemmPipelineScheduler::I
I0, I0,
ik))>{}]; ik))>{}];
b_thread_vec.template AsType<ComputeDataType>()(ik) = b_thread_vec.template AsType<ComputeDataType>()(ik) =
b_thread_bufs[mfma_reg_buf] b_thread_dequant_bufs[mfma_reg_buf]
[Number<b_thread_desc_.CalculateOffset( [Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}]; make_tuple(n0, I0, k0, ik))>{}];
}); });
using mfma_input_type = using mfma_input_type =
...@@ -633,6 +643,13 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3<BlockGemmPipelineScheduler::I ...@@ -633,6 +643,13 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3<BlockGemmPipelineScheduler::I
I0), I0),
a_thread_buf); a_thread_buf);
}); });
// B VGPR->VGPR dequant
b_thread_dequant_copy_.Run(b_block_desc_n0_n1_k0_k1,
b_block_origin_idx,
b_thread_bufs(local_read_buf),
b_thread_desc_,
make_tuple(I0, I0, I0, I0),
b_thread_dequant_bufs(local_read_buf));
} }
else else
{ {
...@@ -652,6 +669,13 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3<BlockGemmPipelineScheduler::I ...@@ -652,6 +669,13 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3<BlockGemmPipelineScheduler::I
I0), I0),
a_thread_buf); a_thread_buf);
}); });
// B VGPR->VGPR dequant
b_thread_dequant_copy_.Run(b_block_desc_n0_n1_k0_k1,
b_block_origin_idx,
b_thread_bufs(mfma_reg_buf),
b_thread_desc_,
make_tuple(I0, I0, I0, I0),
b_thread_dequant_bufs(mfma_reg_buf));
} }
HotLoopScheduler(m0); HotLoopScheduler(m0);
...@@ -691,7 +715,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3<BlockGemmPipelineScheduler::I ...@@ -691,7 +715,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3<BlockGemmPipelineScheduler::I
a_thread_buf[Number<a_thread_desc_.CalculateOffset( a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0 % 2, I0, I0, k0, I0, ik))>{}]; make_tuple(m0 % 2, I0, I0, k0, I0, ik))>{}];
b_thread_vec.template AsType<ComputeDataType>()(ik) = b_thread_vec.template AsType<ComputeDataType>()(ik) =
b_thread_bufs[I0][Number<b_thread_desc_.CalculateOffset( b_thread_dequant_bufs[I0][Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}]; make_tuple(n0, I0, k0, ik))>{}];
}); });
...@@ -720,6 +744,13 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3<BlockGemmPipelineScheduler::I ...@@ -720,6 +744,13 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3<BlockGemmPipelineScheduler::I
make_tuple(Number<(m0 + 1) % 2>{}, I0, I0, k0, I0, I0), make_tuple(Number<(m0 + 1) % 2>{}, I0, I0, k0, I0, I0),
a_thread_buf); a_thread_buf);
}); });
// B VGPR->VGPR dequant
b_thread_dequant_copy_.Run(b_block_desc_n0_n1_k0_k1,
b_block_origin_idx,
b_thread_bufs(I1),
b_thread_desc_,
make_tuple(I0, I0, I0, I0),
b_thread_dequant_bufs(I1));
} }
else else
{ {
...@@ -732,6 +763,13 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3<BlockGemmPipelineScheduler::I ...@@ -732,6 +763,13 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3<BlockGemmPipelineScheduler::I
make_tuple(Number<(m0 + 1) % 2>{}, I0, I0, k0, I0, I0), make_tuple(Number<(m0 + 1) % 2>{}, I0, I0, k0, I0, I0),
a_thread_buf); a_thread_buf);
}); });
// B VGPR->VGPR dequant
b_thread_dequant_copy_.Run(b_block_desc_n0_n1_k0_k1,
b_block_origin_idx,
b_thread_bufs(I0),
b_thread_desc_,
make_tuple(I0, I0, I0, I0),
b_thread_dequant_bufs(I0));
} }
EpilogueScheduler_1(m0); EpilogueScheduler_1(m0);
...@@ -748,7 +786,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3<BlockGemmPipelineScheduler::I ...@@ -748,7 +786,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3<BlockGemmPipelineScheduler::I
a_thread_buf[Number<a_thread_desc_.CalculateOffset(make_tuple( a_thread_buf[Number<a_thread_desc_.CalculateOffset(make_tuple(
(m0 + HotloopLocalBufSwitch) % 2, I0, I0, k0, I0, ik))>{}]; (m0 + HotloopLocalBufSwitch) % 2, I0, I0, k0, I0, ik))>{}];
b_thread_vec.template AsType<ComputeDataType>()(ik) = b_thread_vec.template AsType<ComputeDataType>()(ik) =
b_thread_bufs[I1][Number<b_thread_desc_.CalculateOffset( b_thread_dequant_bufs[I1][Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}]; make_tuple(n0, I0, k0, ik))>{}];
}); });
...@@ -776,6 +814,13 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3<BlockGemmPipelineScheduler::I ...@@ -776,6 +814,13 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3<BlockGemmPipelineScheduler::I
Number<(m0 + 1 + HotloopLocalBufSwitch) % 2>{}, I0, I0, k0, I0, I0), Number<(m0 + 1 + HotloopLocalBufSwitch) % 2>{}, I0, I0, k0, I0, I0),
a_thread_buf); a_thread_buf);
}); });
// B VGPR->VGPR dequant
b_thread_dequant_copy_.Run(b_block_desc_n0_n1_k0_k1,
b_block_origin_idx,
b_thread_bufs(I1),
b_thread_desc_,
make_tuple(I0, I0, I0, I0),
b_thread_dequant_bufs(I1));
EpilogueScheduler_2(); EpilogueScheduler_2();
} }
...@@ -797,7 +842,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3<BlockGemmPipelineScheduler::I ...@@ -797,7 +842,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3<BlockGemmPipelineScheduler::I
a_thread_buf[Number<a_thread_desc_.CalculateOffset( a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0 % 2, I0, I0, k0, I0, ik))>{}]; make_tuple(m0 % 2, I0, I0, k0, I0, ik))>{}];
b_thread_vec.template AsType<ComputeDataType>()(ik) = b_thread_vec.template AsType<ComputeDataType>()(ik) =
b_thread_bufs[I0][Number<b_thread_desc_.CalculateOffset( b_thread_dequant_bufs[I0][Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}]; make_tuple(n0, I0, k0, ik))>{}];
}); });
...@@ -823,6 +868,13 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3<BlockGemmPipelineScheduler::I ...@@ -823,6 +868,13 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3<BlockGemmPipelineScheduler::I
make_tuple(Number<(m0 + 1) % 2>{}, I0, I0, k0, I0, I0), make_tuple(Number<(m0 + 1) % 2>{}, I0, I0, k0, I0, I0),
a_thread_buf); a_thread_buf);
}); });
// B VGPR->VGPR dequant
b_thread_dequant_copy_.Run(b_block_desc_n0_n1_k0_k1,
b_block_origin_idx,
b_thread_bufs(I0),
b_thread_desc_,
make_tuple(I0, I0, I0, I0),
b_thread_dequant_bufs(I0));
EpilogueScheduler_2(); EpilogueScheduler_2();
} }
...@@ -855,6 +907,22 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3<BlockGemmPipelineScheduler::I ...@@ -855,6 +907,22 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3<BlockGemmPipelineScheduler::I
static constexpr BTileDesc b_block_desc_n0_n1_k0_k1; static constexpr BTileDesc b_block_desc_n0_n1_k0_k1;
using Base::c_thread_desc_; using Base::c_thread_desc_;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using BThreadDequantCopy = ThreadwiseTensorSliceTransfer_StaticToStatic<
BDataType,
ComputeDataType,
decltype(b_block_desc_n0_n1_k0_k1),
decltype(b_block_desc_n0_n1_k0_k1),
tensor_operation::element_wise::PassThrough,
Sequence<Number<NRepeat>{}, I1, Number<KRepeat>{}, Number<KPack>{}>,
Sequence<1, 2, 0, 3>,
3,
KPack>;
const PassThrough b_element_op{};
BThreadDequantCopy b_thread_dequant_copy_{b_element_op};
}; };
} // namespace ck } // namespace ck
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment