Commit 047cee2b authored by Anthony Chang's avatar Anthony Chang
Browse files

compiles

parent 68b71534
......@@ -54,25 +54,75 @@ using CElementOp = PassThrough;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
// clang-format off
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmGemm_Xdl_CShuffle
//######| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< ALayout,B0Layout, CLayout, ADataType,B0DataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>;
// clang-format on
using ReferenceGemm0Instance = ck::tensor_operation::host::
ReferenceGemm<ADataType, B0DataType, AccDataType, AccDataType, AElementOp, BElementOp, CElementOp>;
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmGemm_Xdl_CShuffle<
ALayout,
B0Layout,
B1Layout,
CLayout,
ADataType,
B0DataType,
CDataType,
AccDataType,
CShuffleDataType,
AElementOp,
BElementOp,
CElementOp,
GemmDefault,
1,
256,
128, // MPerBlock
128, // NPerBlock
32, // KPerBlock
128, // Gemm1NPerBlock
32, // Gemm1KPerBlock
8, // AK1
8, // BK1
2, // B1K1
32, // MPerXDL
32, // NPerXDL
1, // MXdlPerWave
4, // NXdlPerWave
4, // Gemm1NXdlPerWave
S<4, 64, 1>, // ABlockTransfer
S<1, 0, 2>,
S<1, 0, 2>,
2,
8,
8,
true,
S<4, 64, 1>, // BBlockTransfer
S<1, 0, 2>,
S<1, 0, 2>,
2,
8,
8,
true,
S<8, 32, 1>, // B1BlockTransfer
S<0, 2, 1>,
S<0, 2, 1>,
1,
4,
2,
false,
1, // CShuffleMXdlPerWavePerShuffle
2, // CShuffleNXdlPerWavePerShuffle
S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8>; // CShuffleBlockTransferScalarPerVector_NPerBlock
using ReferenceGemm0Instance = ck::tensor_operation::host::ReferenceGemm<ADataType,
B0DataType,
AccDataType,
AccDataType,
AElementOp,
BElementOp,
CElementOp>;
using ReferenceGemm1Instance = ck::tensor_operation::host::
ReferenceGemm<AccDataType, B1DataType, CDataType, AccDataType, AElementOp, BElementOp, CElementOp>;
int main(int argc, char* argv[])
{
bool do_verification = true;
// int init_method = 1;
int init_method = 3;
int init_method = 1;
bool time_kernel = false;
// GEMM shape
......@@ -87,13 +137,13 @@ int main(int argc, char* argv[])
// ck::index_t StrideC = 1024;
ck::index_t M = 256;
ck::index_t N = 256;
ck::index_t N = 128;
ck::index_t K = 32;
ck::index_t O = 256;
ck::index_t StrideA = 256;
ck::index_t StrideB0 = 256;
ck::index_t StrideB1 = 256;
ck::index_t StrideC = 256;
ck::index_t O = 128;
ck::index_t StrideA = 32;
ck::index_t StrideB0 = 32;
ck::index_t StrideB1 = 128;
ck::index_t StrideC = 128;
if(argc == 1)
{
......@@ -165,14 +215,16 @@ int main(int argc, char* argv[])
b1_n_o.GenerateTensorValue(GeneratorTensor_2<B1DataType>{-5, 5});
break;
case 2:
a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
b0_k_n.GenerateTensorValue(GeneratorTensor_3<B0DataType>{0.0, 1.0});
b1_n_o.GenerateTensorValue(GeneratorTensor_3<B1DataType>{-0.5, 0.5});
a_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{1});
b0_k_n.GenerateTensorValue(GeneratorTensor_1<B0DataType>{1});
b1_n_o.GenerateTensorValue(GeneratorTensor_2<B1DataType>{-5, 5});
break;
default:
a_m_k.GenerateTensorValue(GeneratorTensor_Sequential<0>{});
b0_k_n.GenerateTensorValue(GeneratorTensor_Diagonal<B0DataType>{});
a_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{1});
// b0_k_n.GenerateTensorValue(GeneratorTensor_1<B0DataType>{1});
b0_k_n.GenerateTensorValue(GeneratorTensor_Sequential<1>{});
b1_n_o.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{});
// b1_n_o.GenerateTensorValue(GeneratorTensor_Sequential<0>{});
}
DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace());
......@@ -182,6 +234,7 @@ int main(int argc, char* argv[])
a_m_k_device_buf.ToDevice(a_m_k.mData.data());
b0_k_n_device_buf.ToDevice(b0_k_n.mData.data());
b1_n_o_device_buf.ToDevice(b1_n_o.mData.data());
auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{};
......@@ -192,12 +245,15 @@ int main(int argc, char* argv[])
auto invoker = gemm.MakeInvoker();
auto argument = gemm.MakeArgument(static_cast<ADataType*>(a_m_k_device_buf.GetDeviceBuffer()),
static_cast<B0DataType*>(b0_k_n_device_buf.GetDeviceBuffer()),
static_cast<B1DataType*>(b1_n_o_device_buf.GetDeviceBuffer()),
static_cast<CDataType*>(c_m_o_device_buf.GetDeviceBuffer()),
M,
N,
K,
O,
StrideA,
StrideB0,
StrideB1,
StrideC,
a_element_op,
b_element_op,
......@@ -244,6 +300,15 @@ int main(int argc, char* argv[])
ref_gemm1_invoker.Run(ref_gemm1_argument);
// LogRangeAsType<float>(std::cout << "a_m_k: ", a_m_k.mData, ",") << std::endl;
// LogRangeAsType<float>(std::cout << "b0_k_n : ", b0_k_n.mData, ",") << std::endl;
// LogRangeAsType<float>(std::cout << "b1_n_o : ", b1_n_o.mData, ",") << std::endl;
// LogRangeAsType<float>(std::cout << "c_m_o_device_result : ", c_m_o_device_result.mData, ",") << std::endl;
std::cout << "b0_k_n(0, 0) = " << (float)b0_k_n(0, 0) << ", b0_k_n(1, 0) = " << (float)b0_k_n(1, 0)
<< ", b0_k_n(0, 1) = " << (float)b0_k_n(0, 1) << ", b0_k_n(1, 1) = " << (float)b0_k_n(1, 1)
<< std::endl;
return ck::utils::check_err(c_m_o_device_result.mData, c_m_o_host_result.mData) ? 0 : 1;
}
......
......@@ -25,16 +25,27 @@ constexpr LoopScheduler make_default_loop_scheduler()
#endif // if CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING
}
// Blockwise gemm supporting both regular XDL output M2_M3_M4_M2 and transposed XDL output
// M2_N2_N3_N4. The latter is similar to "SourceSwap" seen in Tensile
// TODO ANT: rename class to reflect the above fact
template <index_t BlockSize,
typename FloatAB,
typename FloatAcc,
typename AK0MK1BlockDesc,
typename AK0MK1BlockDesc, // could be thread desc
typename BK0NK1BlockDesc,
typename AMmaTileDesc,
typename BMmaTileDesc,
index_t MPerBlock,
index_t NPerBlock,
index_t KPerBlock,
index_t MPerXDL,
index_t NPerXDL,
index_t MRepeat,
index_t NRepeat,
index_t KPack>
index_t KPack,
bool TransposeC = false,
index_t AMmaKStride = KPack * XdlopsGemm<FloatAB, MPerXDL, NPerXDL, KPack, TransposeC>{}.K0PerXdlops,
index_t BMmaKStride = KPack * XdlopsGemm<FloatAB, MPerXDL, NPerXDL, KPack, TransposeC>{}.K0PerXdlops>
struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
{
static constexpr auto I0 = Number<0>{};
......@@ -46,23 +57,28 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
static constexpr index_t WaveSize = get_warp_size();
static constexpr index_t MPerBlock = AK0MK1BlockDesc{}.GetLength(I1);
static constexpr index_t NPerBlock = BK0NK1BlockDesc{}.GetLength(I1);
static constexpr index_t KPerBlock =
BK0NK1BlockDesc{}.GetLength(I0) * BK0NK1BlockDesc{}.GetLength(I2);
// static constexpr index_t MPerBlock = AK0MK1BlockDesc{}.GetLength(I1);
// static constexpr index_t NPerBlock = BK0NK1BlockDesc{}.GetLength(I1);
// static constexpr index_t KPerBlock =
// BK0NK1BlockDesc{}.GetLength(I0) * BK0NK1BlockDesc{}.GetLength(I2);
static constexpr index_t A_K0 = AK0MK1BlockDesc{}.GetLength(I0);
static constexpr index_t B_K0 = BK0NK1BlockDesc{}.GetLength(I0);
static constexpr index_t A_K1 = AK0MK1BlockDesc{}.GetLength(I2);
static constexpr index_t B_K1 = BK0NK1BlockDesc{}.GetLength(I2);
static constexpr auto xdlops_gemm = XdlopsGemm<FloatAB, MPerXDL, NPerXDL, KPack>{};
static constexpr auto xdlops_gemm = XdlopsGemm<FloatAB, MPerXDL, NPerXDL, KPack, TransposeC>{};
static constexpr index_t KPerThread = KPerBlock / xdlops_gemm.K0PerXdlops;
static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerXDL);
static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerXDL);
// StaticBuffer<AddressSpaceEnum::Vgpr,
// FloatAcc,
// MRepeat * NRepeat * xdlops_gemm.GetRegSizePerXdlops(),
// true>
// c_thread_buf_;
StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr,
FloatAcc,
MRepeat * NRepeat,
......@@ -92,7 +108,7 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
const auto xdlops_a_idx = xdlops_gemm.CalculateAThreadOriginDataIndex();
return make_tuple(0, waveId_m, xdlops_a_idx[I1], KPerThread * xdlops_a_idx[I0]);
return make_tuple(0, waveId_m, xdlops_a_idx[I1], KPack * xdlops_a_idx[I0]);
}
__device__ static auto CalculateBThreadOriginDataIndex()
......@@ -103,7 +119,7 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
const auto xdlops_b_idx = xdlops_gemm.CalculateBThreadOriginDataIndex();
return make_tuple(0, waveId_n, xdlops_b_idx[I1], KPerThread * xdlops_b_idx[I0]);
return make_tuple(0, waveId_n, xdlops_b_idx[I1], KPack * xdlops_b_idx[I0]);
}
template <index_t m0, index_t n0, index_t xdlops_i, index_t blk_i>
......@@ -135,10 +151,30 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
return make_tuple(c_thread_m, c_thread_n);
}
__host__ __device__ BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1()
using Tuple4 = decltype(CalculateAThreadOriginDataIndex());
__host__ __device__ BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1(
Tuple4 a_origin = CalculateAThreadOriginDataIndex(),
Tuple4 b_origin = CalculateBThreadOriginDataIndex())
: a_thread_copy_(a_origin), b_thread_copy_(b_origin)
{
static_assert(AK0MK1BlockDesc::IsKnownAtCompileTime() &&
BK0NK1BlockDesc::IsKnownAtCompileTime(),
#if 0
if(!TransposeC && hipThreadIdx_x % 32 < 8)
{
printf("bid %zd tid %zd, a_mma = %d, %d, %d, %d, b_mma = %d, %d, %d, %d\n",
hipBlockIdx_x,
hipThreadIdx_x,
a_origin[Number<0>{}],
a_origin[Number<1>{}],
a_origin[Number<2>{}],
a_origin[Number<3>{}],
b_origin[Number<0>{}],
b_origin[Number<1>{}],
b_origin[Number<2>{}],
b_origin[Number<3>{}]);
}
#endif
static_assert(AMmaTileDesc::IsKnownAtCompileTime() && BMmaTileDesc::IsKnownAtCompileTime(),
"wrong! Desc should be known at compile-time");
static_assert(ThisThreadBlock::GetNumOfThread() == MWaves * NWaves * WaveSize,
......@@ -148,6 +184,27 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
"wrong!");
}
__host__ __device__ BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1(
const BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1& other)
: a_thread_copy_(other.a_origin), b_thread_copy_(other.b_origin)
{
}
// transposed XDL output supporting C_xdl' = B_xdl' * A_xdl'
__host__ __device__ static constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4()
{
constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths();
constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0];
constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[I1];
constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2];
constexpr auto N = c_m0_m1_m2_n_tblk_lens[I3];
return make_naive_tensor_descriptor_packed(
make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, I1, I1, N, M0, M1, M2));
}
// XDL output supporting C_xdl = A_xdl * B_xdl
__host__ __device__ static constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2()
{
constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths();
......@@ -174,6 +231,21 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
make_tuple(I1, Number<MRepeat>{}, Number<NRepeat>{}, I1, I1, M0, M1, M2, N));
}
// transposed XDL output supporting C_xdl' = B_xdl' * A_xdl'
__host__ __device__ static constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4()
{
constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2 =
make_naive_tensor_descriptor_packed(make_tuple(Number<MRepeat>{},
Number<NRepeat>{},
Number<MWaves>{},
Number<NWaves>{},
Number<MPerXDL>{},
Number<NPerXDL>{}));
return xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_N2_N3_N4(c_block_desc_m0_n0_m1_n1_m2_n2);
}
// XDL output supporting C_xdl = A_xdl * B_xdl
__host__ __device__ static constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2()
{
constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2 =
......@@ -239,33 +311,10 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
c_grid_desc_g_m0_n0_m1_n1_m2_n2);
}
__host__ __device__ static constexpr auto MakeABlockDescriptor_M0_M1_M2_K()
{
return transform_tensor_descriptor(
AK0MK1BlockDesc{},
make_tuple(
make_merge_transform_v3_division_mod(make_tuple(Number<A_K0>{}, Number<A_K1>{})),
make_unmerge_transform(
make_tuple(Number<MRepeat>{}, Number<MWaves>{}, Number<MPerXDL>{}))),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}),
make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{}));
}
__host__ __device__ static constexpr auto MakeBBlockDescriptor_N0_N1_N2_K()
{
return transform_tensor_descriptor(
BK0NK1BlockDesc{},
make_tuple(
make_merge_transform_v3_division_mod(make_tuple(Number<B_K0>{}, Number<B_K1>{})),
make_unmerge_transform(
make_tuple(Number<NRepeat>{}, Number<NWaves>{}, Number<NPerXDL>{}))),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}),
make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{}));
}
static constexpr auto a_block_desc_m0_m1_m2_k = MakeABlockDescriptor_M0_M1_M2_K();
static constexpr auto b_block_desc_n0_n1_n2_k = MakeBBlockDescriptor_N0_N1_N2_K();
static constexpr AMmaTileDesc a_block_desc_m0_m1_m2_k;
static constexpr BMmaTileDesc b_block_desc_n0_n1_n2_k;
// NOTE ANT: a_block_buf for the 2nd gemm is vgpr buffer
template <typename ABlockBuffer, typename BBlockBuffer, typename CThreadBuffer>
__device__ void Run(const ABlockBuffer& a_block_buf,
const BBlockBuffer& b_block_buf,
......@@ -276,33 +325,65 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatAB>(
b_thread_desc_.GetElementSpaceSize());
static_for<0, MRepeat, 1>{}([&](auto m0) {
// read A
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
make_tuple(m0, I0, I0, I0),
a_block_buf,
a_thread_desc_,
make_tuple(I0, I0, I0, I0),
a_thread_buf);
static_for<0, NRepeat, 1>{}([&](auto n0) {
// read B
b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
make_tuple(n0, I0, I0, I0),
b_block_buf,
b_thread_desc_,
// static_for<0, KPerBlock, KPack * xdlops_gemm.K0PerXdlops>{}([&](auto k) {
static_for<0, KPerThread / KPack, 1>{}([&](auto k) { // k=0,1,2 instead of k=kpack*[0, 1, 2]
static_for<0, MRepeat, 1>{}([&](auto m0) {
// read A1 without stride
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
make_tuple(m0, I0, I0, Number<k * AMmaKStride>{}),
a_block_buf,
a_thread_desc_,
make_tuple(I0, I0, I0, I0),
b_thread_buf);
a_thread_buf);
static_for<0, KPerThread, KPack>{}([&](auto k) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
// read B with stride
b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}),
b_block_buf,
b_thread_desc_,
make_tuple(I0, I0, I0, I0),
b_thread_buf);
#if 0
if (!TransposeC && hipThreadIdx_x % 32 < 8) {
printf("bid %zd tid %zd, mma tile %d %d %d, a[0:3] = %.0f, %.0f, %.0f, %.0f, %.0f, %.0f, %.0f, %.0f, b[0:3] = %.0f, %.0f, %.0f, %.0f, %.0f, %.0f, %.0f, %.0f\n",
hipBlockIdx_x, hipThreadIdx_x, m0.value, n0.value, k.value,
// (float)a_thread_buf[Number<0>{}],
// (float)a_thread_buf[Number<1>{}],
// (float)a_thread_buf[Number<2>{}],
// (float)a_thread_buf[Number<3>{}],
// (float)b_thread_buf[Number<0>{}],
// (float)b_thread_buf[Number<1>{}],
// (float)b_thread_buf[Number<2>{}],
// (float)b_thread_buf[Number<3>{}]
(float)a_thread_buf[Number<a_thread_desc_.CalculateOffset(make_tuple(0, 0, 0, /* k + */ 0))>{}],
(float)a_thread_buf[Number<a_thread_desc_.CalculateOffset(make_tuple(0, 0, 0, /* k + */ 1))>{}],
(float)a_thread_buf[Number<a_thread_desc_.CalculateOffset(make_tuple(0, 0, 0, /* k + */ 2))>{}],
(float)a_thread_buf[Number<a_thread_desc_.CalculateOffset(make_tuple(0, 0, 0, /* k + */ 3))>{}],
(float)a_thread_buf[Number<a_thread_desc_.CalculateOffset(make_tuple(0, 0, 0, /* k + */ 4))>{}],
(float)a_thread_buf[Number<a_thread_desc_.CalculateOffset(make_tuple(0, 0, 0, /* k + */ 5))>{}],
(float)a_thread_buf[Number<a_thread_desc_.CalculateOffset(make_tuple(0, 0, 0, /* k + */ 6))>{}],
(float)a_thread_buf[Number<a_thread_desc_.CalculateOffset(make_tuple(0, 0, 0, /* k + */ 7))>{}],
(float)b_thread_buf[Number<b_thread_desc_.CalculateOffset(make_tuple(0, 0, 0, /* k + */ 0))>{}],
(float)b_thread_buf[Number<b_thread_desc_.CalculateOffset(make_tuple(0, 0, 0, /* k + */ 1))>{}],
(float)b_thread_buf[Number<b_thread_desc_.CalculateOffset(make_tuple(0, 0, 0, /* k + */ 2))>{}],
(float)b_thread_buf[Number<b_thread_desc_.CalculateOffset(make_tuple(0, 0, 0, /* k + */ 3))>{}],
(float)b_thread_buf[Number<b_thread_desc_.CalculateOffset(make_tuple(0, 0, 0, /* k + */ 4))>{}],
(float)b_thread_buf[Number<b_thread_desc_.CalculateOffset(make_tuple(0, 0, 0, /* k + */ 5))>{}],
(float)b_thread_buf[Number<b_thread_desc_.CalculateOffset(make_tuple(0, 0, 0, /* k + */ 6))>{}],
(float)b_thread_buf[Number<b_thread_desc_.CalculateOffset(make_tuple(0, 0, 0, /* k + */ 7))>{}]
);
}
#endif
vector_type<FloatAB, KPack> a_thread_vec;
vector_type<FloatAB, KPack> b_thread_vec;
// xdlops_gemm.K0PerXdlops
// TODO ANT: add appropriate iteration delta
static_for<0, KPack, 1>{}([&](auto i) {
a_thread_vec.template AsType<FloatAB>()(i) = a_thread_buf
[Number<a_thread_desc_.CalculateOffset(make_tuple(0, 0, 0, k + i))>{}];
[Number<a_thread_desc_.CalculateOffset(make_tuple(0, 0, 0, i))>{}];
b_thread_vec.template AsType<FloatAB>()(i) = b_thread_buf
[Number<b_thread_desc_.CalculateOffset(make_tuple(0, 0, 0, k + i))>{}];
[Number<b_thread_desc_.CalculateOffset(make_tuple(0, 0, 0, i))>{}];
});
using mfma_input_type =
......@@ -337,7 +418,7 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
FloatAB,
decltype(a_block_desc_m0_m1_m2_k),
decltype(a_thread_desc_),
Sequence<1, 1, 1, KPerThread>,
Sequence<1, 1, 1, KPack>,
Sequence<0, 1, 2, 3>,
3,
A_K1,
......@@ -347,16 +428,17 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
FloatAB,
decltype(b_block_desc_n0_n1_n2_k),
decltype(b_thread_desc_),
Sequence<1, 1, 1, KPerThread>,
Sequence<1, 1, 1, KPack>,
Sequence<0, 1, 2, 3>,
3,
B_K1,
B_K1>;
AThreadCopy a_thread_copy_{CalculateAThreadOriginDataIndex()};
BThreadCopy b_thread_copy_{CalculateBThreadOriginDataIndex()};
AThreadCopy a_thread_copy_; // {CalculateAThreadOriginDataIndex()};
BThreadCopy b_thread_copy_; // {CalculateBThreadOriginDataIndex()};
};
#if 0
// Note: To facilitate the inter-wave loop scheduler, we need to explicitly set the macro
// CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING=1 as a few intrinsics are not yet available in
// the latest ROCm release. For unsupported compilers, inter-wave loop scheduler falls back to the
......@@ -584,5 +666,6 @@ constexpr auto BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector()
KPack>{};
}
};
#endif
} // namespace ck
......@@ -81,7 +81,21 @@ struct ThreadGroupTensorSliceTransfer_v4r1
make_multi_index(ThreadGroup::GetThreadId()));
const auto thread_data_idx_begin = thread_cluster_idx * thread_slice_lengths;
#if 0
if (std::is_same<Sequence<16,64,2>, BlockSliceLengths>::value)
{
auto s = src_block_slice_origin + thread_data_idx_begin;
auto d = dst_block_slice_origin + thread_data_idx_begin;
printf("bid %zd tid %zd, src origin %d %d %d, dst origin %d %d %d\n",
hipBlockIdx_x, hipThreadIdx_x,
s[Number<0>{}],
s[Number<1>{}],
s[Number<2>{}],
d[Number<0>{}],
d[Number<1>{}],
d[Number<2>{}]);
}
#endif
threadwise_transfer_.SetSrcSliceOrigin(src_desc,
src_block_slice_origin + thread_data_idx_begin);
threadwise_transfer_.SetDstSliceOrigin(dst_desc,
......
......@@ -1145,9 +1145,52 @@ struct ThreadwiseTensorSliceTransfer_v4
src_desc, src_data_coord);
// copy data from src_buf into src_tmp_vector
#if 0
src_tmp_vector.template AsType<src_vector_t>()(Number<0>{}) =
src_buf.template Get<src_vector_t>(src_data_coord.GetOffset(), is_src_valid);
#else
if constexpr(SrcBuffer::IsDynamicBuffer())
{
src_tmp_vector.template AsType<src_vector_t>()(Number<0>{}) =
src_buf.template Get<src_vector_t>(src_data_coord.GetOffset(), is_src_valid);
}
else if constexpr(SrcBuffer::IsStaticBuffer())
{
static_for<0, SrcScalarPerVector, 1>{}([&](auto i) {
constexpr index_t src_offset = src_desc.CalculateOffset(
src_ref_to_origin_disp_idx + data_to_origin_disp_idx + i * src_scalar_step_in_vector);
// apply type convert
src_tmp_vector.template AsType<SrcData>()(i) =
src_buf[Number<src_offset>{}];
});
// if constexpr(StaticBufferTupleOfVector)
// {
// // constexpr auto offset_nd = SrcRefToOriginDisplacement{} + data_to_origin_disp_idx;
// // // offset_nd.foo();
// // constexpr auto offset = src_desc.CalculateOffset(offset_nd);
// // src_tmp_vector.template AsType<src_vector_t>()(Number<0>{}) =
// // src_buf.template GetAsType<src_vector_t>(Number<offset>{});
// static_for<0, SrcScalarPerVector, 1>{}([&](auto i) {
// // constexpr auto src_offset_nd = src_ref_to_origin_disp_idx +
// // data_to_origin_disp_idx + i * src_scalar_step_in_vector;
// // constexpr auto src_offset = src_desc.CalculateOffset(src_offset_nd);
// constexpr auto src_offset = src_desc.CalculateOffset(SrcRefToOriginDisplacement{});
// // SrcData s = src_buf[Number<src_offset>{}];
// SrcData s = src_buf[Number<0>{}];
// // apply type convert
// src_tmp_vector.template AsType<SrcData>()(i) = s;
// });
// }
// else
// {
// src_tmp_vector.template AsType<src_vector_t>()(Number<0>{}) =
// src_buf.template Get<src_vector_t>(src_data_coord.GetOffset(),
// is_src_valid);
// }
}
#endif
// copy data from src_tmp_vector to dst_tmp_vector (data cast data from SrcData to
// DstData)
vector_type_maker_t<DstData, SrcScalarPerVector> dst_tmp_vector;
......@@ -1184,4 +1227,93 @@ struct ThreadwiseTensorSliceTransfer_v4
SrcCoord src_ref_coord_;
};
// Do NOT involve any tensor coordinates with StaticBuffer
template <typename SrcData,
typename DstData,
typename SrcDesc,
typename DstDesc,
typename SliceLengths,
typename DimAccessOrder,
index_t DstVectorDim,
index_t DstScalarPerVector,
// InMemoryDataOperationEnum DstInMemOp,
// index_t DstScalarStrideInVector,
typename enable_if<SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(), bool>::type = false>
struct ThreadwiseTensorSliceTransfer_v1r3_Static
{
static constexpr index_t nDim = SliceLengths::Size();
using Index = MultiIndex<nDim>;
__device__ constexpr ThreadwiseTensorSliceTransfer_v1r3_Static()
{
static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
"wrong! Desc need to known at compile-time");
static_assert(SliceLengths::At(Number<DstVectorDim>{}) % DstScalarPerVector == 0,
"wrong! Not divisible");
}
template <typename SrcSliceOriginIdx, typename DstSliceOriginIdx, typename SrcBuffer, typename DstBuffer>
__device__ void Run(const SrcDesc&,
const SrcSliceOriginIdx&,
const SrcBuffer& src_buf,
const DstDesc&,
const DstSliceOriginIdx&,
DstBuffer& dst_buf)
{
static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
"wrong! Desc need to known at compile-time");
static_assert(is_known_at_compile_time<remove_cvref_t<SrcSliceOriginIdx>>::value &&
is_known_at_compile_time<remove_cvref_t<DstSliceOriginIdx>>::value,
"wrong! SliceOrigin need to known at compile-time");
static_assert(SrcBuffer::IsStaticBuffer() && DstBuffer::IsStaticBuffer(),
"wrong! Buffer need to be StaticBuffer");
// SrcDesc and src_slice_origin_idx are known at compile-time
constexpr auto src_desc = remove_cvref_t<SrcDesc>{};
constexpr auto dst_desc = remove_cvref_t<DstDesc>{};
constexpr auto src_slice_origin_idx = to_multi_index(SrcSliceOriginIdx{});
constexpr auto dst_slice_origin_idx = to_multi_index(DstSliceOriginIdx{});
// scalar per access on each dim
constexpr auto dst_scalar_per_access = generate_sequence(
detail::lambda_scalar_per_access<DstVectorDim, DstScalarPerVector>{}, Number<nDim>{});
constexpr auto dst_scalar_step_in_vector =
generate_sequence(detail::lambda_scalar_step_in_vector<DstVectorDim>{}, Number<nDim>{});
using SpaceFillingCurve = SpaceFillingCurve<SliceLengths,
DimAccessOrder,
remove_cv_t<decltype(dst_scalar_per_access)>>;
static_assert(DstScalarPerVector == SpaceFillingCurve::ScalarPerVector,
"wrong!DstScalarPerVector != SpaceFillingCurve::ScalarPerVector");
typename vector_type_maker<DstData, DstScalarPerVector>::type dst_vector;
using dst_vector_t = typename vector_type_maker<DstData, DstScalarPerVector>::type::type;
constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess();
static_for<0, num_access, 1>{}([&](auto idx_1d) {
constexpr auto idx_md = SpaceFillingCurve::GetIndex(idx_1d);
// copy data from src_buf into dst_vector
static_for<0, DstScalarPerVector, 1>{}([&](auto i) {
constexpr index_t src_offset = src_desc.CalculateOffset(
src_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector);
constexpr index_t dst_offset = dst_desc.CalculateOffset(
dst_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector);
dst_buf(Number<dst_offset>{}) = src_buf[Number<src_offset>{}];
});
});
}
};
} // namespace ck
......@@ -30,6 +30,17 @@ enum struct MfmaInstr
mfma_f64_16x16x4f64
};
// template <typename T, bool TransposeC>
// struct mfma_base_type
// {
// template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
// __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
// {
// if constexpr (!TransposeC) T::run(a, b, reg_c);
// else T::run(b, a, reg_c);
// }
// };
template <MfmaInstr instr>
struct mfma_type;
......@@ -579,7 +590,11 @@ struct MfmaSelector
static constexpr index_t GetK1PerXdlops() { return selected_mfma.k_per_blk; }
};
template <typename base_type, index_t MPerXdlops, index_t NPerXdlops, index_t KPack>
template <typename base_type,
index_t MPerXdlops,
index_t NPerXdlops,
index_t KPack,
bool TransposeC = false>
struct XdlopsGemm
{
static constexpr auto I0 = Number<0>{};
......@@ -612,6 +627,8 @@ struct XdlopsGemm
static_assert(KPack % mfma_instr.k_per_blk == 0, "KPack cannot be divided by k_per_blk");
}
// XDL output supporting C = A * B
// M2_N2 -> M2_M3_M4_N2
template <typename CDesc_M0_N0_M1_N1_M2_N2>
__host__ __device__ static constexpr auto
MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(const CDesc_M0_N0_M1_N1_M2_N2& c_desc_m0_n0_m1_n1_m2_n2)
......@@ -645,6 +662,41 @@ struct XdlopsGemm
Sequence<7>{}));
}
// transposed XDL output supporting C' = B' * A'
// M2_N2 -> M2_N2_N3_N4
template <typename CDesc_M0_N0_M1_N1_M2_N2>
__host__ __device__ static constexpr auto
MakeCDescriptor_M0_N0_M1_N1_M2_N2_N3_N4(const CDesc_M0_N0_M1_N1_M2_N2& c_desc_m0_n0_m1_n1_m2_n2)
{
const auto M0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I0);
const auto N0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I1);
const auto M1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I2);
const auto N1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I3);
return transform_tensor_descriptor(
c_desc_m0_n0_m1_n1_m2_n2,
make_tuple(make_pass_through_transform(M0),
make_pass_through_transform(N0),
make_pass_through_transform(M1),
make_pass_through_transform(N1),
make_pass_through_transform(mfma_instr.num_threads_per_blk),
make_unmerge_transform(make_tuple(mfma_instr.num_groups_per_blk,
mfma_instr.num_input_blks,
mfma_instr.group_size))),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
Sequence<3>{},
Sequence<4>{},
Sequence<5>{}),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
Sequence<3>{},
Sequence<4>{},
Sequence<5, 6, 7>{}));
}
template <typename CDesc_G_M0_N0_M1_N1_M2_N2>
__host__ __device__ static constexpr auto MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(
const CDesc_G_M0_N0_M1_N1_M2_N2& c_desc_g_m0_n0_m1_n1_m2_n2)
......@@ -698,7 +750,14 @@ struct XdlopsGemm
"base base_type must be double, float, half, bfloat16, and int8_t!");
static_for<0, KPack / mfma_instr.k_per_blk, 1>{}([&](auto k) {
mfma_instr.template run<MPerXdlops, NPerXdlops>(p_a_wave[k], p_b_wave[k], p_c_thread);
if constexpr (!TransposeC)
{
mfma_instr.template run<MPerXdlops, NPerXdlops>(p_a_wave[k], p_b_wave[k], p_c_thread);
}
else
{
mfma_instr.template run<MPerXdlops, NPerXdlops>(p_b_wave[k], p_a_wave[k], p_c_thread);
}
});
}
......
......@@ -61,6 +61,7 @@ struct StaticBufferTupleOfVector
static constexpr auto s_per_v = Number<ScalarPerVector>{};
static constexpr auto num_of_v_ = Number<NumOfVector>{};
static constexpr auto s_per_buf = s_per_v * num_of_v_;
__host__ __device__ constexpr StaticBufferTupleOfVector() : base{} {}
......@@ -70,6 +71,7 @@ struct StaticBufferTupleOfVector
__host__ __device__ static constexpr bool IsDynamicBuffer() { return false; }
__host__ __device__ static constexpr index_t Size() { return s_per_buf; };
// Get S
// i is offset of S
template <index_t I>
......
......@@ -78,4 +78,10 @@ __host__ __device__ constexpr auto transform_tuples(F f, const X& x, const Y& y,
f, x, y, z, typename arithmetic_sequence_gen<0, X::Size(), 1>::type{});
}
template <index_t... Is>
__host__ __device__ constexpr Tuple<Number<Is>...> to_tuple(Sequence<Is...>)
{
return Tuple<Number<Is>...>(Number<Is>{}...);
}
} // namespace ck
......@@ -134,7 +134,7 @@ check_err(const std::vector<T>& out,
{
max_err = err > max_err ? err : max_err;
err_count++;
if(err_count < 5)
if(err_count < 128)
{
std::cout << msg << std::setw(12) << std::setprecision(7) << " out[" << i
<< "] != ref[" << i << "]: " << o << " != " << r << std::endl;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment