Commit 4d595f02 authored by aska-0096

Bug fix; Rename example to fp16int8; clang-format

parent fdfb2f61
add_example_executable(example_gemm_multiply_multiply_xdl_fp8 gemm_multiply_multiply_xdl_fp8.cpp)
add_example_executable(example_gemm_multiply_multiply_xdl_fp8_ab_scale gemm_multiply_multiply_xdl_fp8_ab_scale.cpp)
-add_example_executable(example_gemm_multiply_multiply_xdl_fp8_b_scale gemm_multiply_multiply_xdl_fp8_b_scale.cpp)
+add_example_executable(example_gemm_fp16int8_b_scale gemm_fp16int8_b_scale.cpp)
add_example_executable(example_gemm_add_add_xdl_fp16 gemm_add_add_xdl_fp16.cpp)
@@ -37,6 +37,7 @@ using A0DataType = FP16;
// using A1DataType = F32;
// using B0DataType = FP8;
// using B1DataType = F32;
+using QuantDataType = int8_t;
using B0DataType = uint8_t;
using B1DataType = FP16;
using AccDataType = F32;
@@ -69,21 +70,54 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_BScale_X
    A0DataType, B0DataType, B1DataType, DsDataType, EDataType, AccDataType, CShuffleDataType,
    AElementOp, BElementOp, CDEElementOp, GemmSpec,
    256, Scale_Block_N, Scale_Block_K,
    128, 128,
    128,
    // 16, 16,
    8, 8,
    16, 16,
    4, 4,
    // S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
    // S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
-   S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0,
-   S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0,
+   S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0,
+   S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0,
    1, 2, S<1, 32, 1, 8>, S<8, 8, 1>,
    // ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, FP8>;
    ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3>;
// clang-format on
+template <typename IntType>
+struct UnsignedWeightPreprocessor
+{
+};
+template <>
+struct UnsignedWeightPreprocessor<int8_t>
+{
+    using UnsignedWeight = Tensor<uint8_t>;
+    using SignedWeight = Tensor<int8_t>;
+    static UnsignedWeight convert(SignedWeight const& Input)
+    {
+        UnsignedWeight Output = Input.template CopyAsType<uint8_t>();
+        auto f_kn = [&](auto k, auto n) {
+            const uint8_t adder = 128;
+            int8_t v_signed_weight;
+            uint8_t v_unsigned_weight;
+            ck::tensor_operation::element_wise::PassThrough{}(v_signed_weight, Input(k, n));
+            v_unsigned_weight = ck::type_convert<uint8_t>(v_signed_weight) + adder;
+            Output(k, n) = v_unsigned_weight;
+        };
+        make_ParallelTensorFunctor(f_kn, Input.mDesc.GetLengths()[0], Input.mDesc.GetLengths()[1])(
+            std::thread::hardware_concurrency());
+        return Output;
+    }
+    UnsignedWeight operator()(SignedWeight const& Input) { return convert(Input); }
+};
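The added UnsignedWeightPreprocessor shifts every signed int8 weight by a zero-point of 128 so the tensor can be handed to the kernel as uint8 (-128..127 maps to 0..255). A minimal standalone sketch of the same shift, using std::vector in place of the CK Tensor type and omitting the parallel functor:

#include <cstdint>
#include <vector>

// Shift signed int8 weights into the uint8 range by adding a zero-point of 128.
// Unsigned 8-bit wrap-around makes -128 -> 0, 0 -> 128, 127 -> 255.
std::vector<uint8_t> to_unsigned_weights(const std::vector<int8_t>& signed_w)
{
    std::vector<uint8_t> unsigned_w(signed_w.size());
    for(std::size_t i = 0; i < signed_w.size(); ++i)
    {
        unsigned_w[i] = static_cast<uint8_t>(static_cast<uint8_t>(signed_w[i]) + 128u);
    }
    return unsigned_w;
}

Because the arithmetic is modulo 256, adding 128 is equivalent to flipping the sign bit, which is why the original signed values can be recovered exactly on the other side.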
int main(int argc, char* argv[])
{
    bool do_verification = true;
@@ -154,7 +188,8 @@ int main(int argc, char* argv[])
    // (K + Scale_Block_K - 1) / Scale_Block_K,
    // Scale_Stride_AM,
    // A0Layout{}));
-   Tensor<B0DataType> b0_k_n(f_host_tensor_descriptor(K, N, StrideB, B0Layout{}));
+   Tensor<QuantDataType> quant_b0_k_n(f_host_tensor_descriptor(K, N, StrideB, B0Layout{}));
+   // Tensor<B0DataType> b0_k_n(f_host_tensor_descriptor(K, N, StrideB, B0Layout{}));
    Tensor<B1DataType> b1_k_n(f_host_tensor_descriptor((K + Scale_Block_K - 1) / Scale_Block_K,
                                                       (N + Scale_Block_N - 1) / Scale_Block_N,
                                                       Scale_Stride_BN,
@@ -164,7 +199,7 @@ int main(int argc, char* argv[])
    std::cout << "a0_m_k: " << a0_m_k.mDesc << std::endl;
    // std::cout << "a1_m_k: " << a1_m_k.mDesc << std::endl;
-   std::cout << "b0_k_n: " << b0_k_n.mDesc << std::endl;
+   std::cout << "b0_k_n: " << quant_b0_k_n.mDesc << std::endl;
    std::cout << "b1_k_n: " << b1_k_n.mDesc << std::endl;
    std::cout << "e_m_n: " << e_m_n_host_result.mDesc << std::endl;
@@ -174,36 +209,39 @@ int main(int argc, char* argv[])
    case 0: break;
    case 1:
        a0_m_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-2, 2});
-       b0_k_n.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
+       quant_b0_k_n.GenerateTensorValue(GeneratorTensor_2<QuantDataType>{-2, 2});
        // a1_m_k.GenerateTensorValue(GeneratorTensor_3<A1DataType>{0, 1.0});
        b1_k_n.GenerateTensorValue(GeneratorTensor_3<B1DataType>{0, 1.0});
        break;
    case 2:
        a0_m_k.GenerateTensorValue(GeneratorTensor_1<A0DataType>{});
-       b0_k_n.GenerateTensorValue(GeneratorTensor_1<B0DataType>{});
+       quant_b0_k_n.GenerateTensorValue(GeneratorTensor_1<QuantDataType>{});
        // a1_m_k.GenerateTensorValue(GeneratorTensor_1<A1DataType>{});
        b1_k_n.GenerateTensorValue(GeneratorTensor_1<B1DataType>{});
        break;
    case 3:
        a0_m_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-2, 2});
-       b0_k_n.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
+       quant_b0_k_n.GenerateTensorValue(GeneratorTensor_2<QuantDataType>{-2, 2});
        // a1_m_k.GenerateTensorValue(GeneratorTensor_1<A1DataType>{});
        b1_k_n.GenerateTensorValue(GeneratorTensor_1<B1DataType>{});
        break;
    case 4:
        a0_m_k.GenerateTensorValue(GeneratorTensor_1<A0DataType>{});
-       b0_k_n.GenerateTensorValue(GeneratorTensor_1<B0DataType>{});
+       quant_b0_k_n.GenerateTensorValue(GeneratorTensor_1<QuantDataType>{});
        // a1_m_k.GenerateTensorValue(GeneratorTensor_3<A1DataType>{0, 1.0});
        b1_k_n.GenerateTensorValue(GeneratorTensor_3<B1DataType>{0, 1.0});
        break;
    default:
        a0_m_k.GenerateTensorValue(GeneratorTensor_3<A0DataType>{-0.5, 0.5});
-       b0_k_n.GenerateTensorValue(GeneratorTensor_3<B0DataType>{-0.5, 0.5});
+       quant_b0_k_n.GenerateTensorValue(GeneratorTensor_3<QuantDataType>{-0.5, 0.5});
        // a1_m_k.GenerateTensorValue(GeneratorTensor_3<A1DataType>{0, 1.0});
        b1_k_n.GenerateTensorValue(GeneratorTensor_3<B1DataType>{0, 1.0});
    }
#endif
+   UnsignedWeightPreprocessor<QuantDataType> preprocessor;
+   Tensor<B0DataType> b0_k_n = preprocessor(quant_b0_k_n);
    DeviceMem a0_device_buf(sizeof(A0DataType) * a0_m_k.mDesc.GetElementSpaceSize());
    // DeviceMem a1_device_buf(sizeof(A1DataType) * a1_m_k.mDesc.GetElementSpaceSize());
    DeviceMem b0_device_buf(sizeof(B0DataType) * b0_k_n.mDesc.GetElementSpaceSize());
@@ -282,8 +320,8 @@ int main(int argc, char* argv[])
        {
            for(int k = 0; k < K; k++)
            {
-               b_k_n(k, n) = ck::type_convert<float>(b0_k_n(k, n)) *
-                             b1_k_n(k / Scale_Block_K, n / Scale_Block_N);
+               b_k_n(k, n) = ck::type_convert<float>(quant_b0_k_n(k, n)) *
+                             ck::type_convert<float>(b1_k_n(k / Scale_Block_K, n / Scale_Block_N));
            }
        }
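In the updated reference check, B is dequantized on the host before the fp32 reference GEMM: each int8 weight is multiplied by the block scale that covers its (k, n) position, i.e. b(k, n) = float(quant_b0(k, n)) * float(b1(k / Scale_Block_K, n / Scale_Block_N)). A standalone sketch of this block-wise dequantization, assuming row-major std::vector storage and float scales (the example itself uses the CK Tensor type with FP16 scales):

#include <cstdint>
#include <vector>

// Dequantize a K x N int8 weight matrix with one float scale per
// scale_block_k x scale_block_n tile (both matrices stored row-major as k * cols + n).
std::vector<float> dequantize_b(const std::vector<int8_t>& b_int8,
                                const std::vector<float>& scale,
                                int K, int N, int scale_block_k, int scale_block_n)
{
    const int scale_n = (N + scale_block_n - 1) / scale_block_n; // scales per row of blocks
    std::vector<float> b(static_cast<std::size_t>(K) * N);
    for(int k = 0; k < K; ++k)
        for(int n = 0; n < N; ++n)
            b[static_cast<std::size_t>(k) * N + n] =
                static_cast<float>(b_int8[static_cast<std::size_t>(k) * N + n]) *
                scale[static_cast<std::size_t>(k / scale_block_k) * scale_n + n / scale_block_n];
    return b;
}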
...
@@ -7,15 +7,15 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t,
                      layernorm2d_fwd_args a,
                      const ck_tile::stream_config& s)
{
-   if(t.data_type.compare("fp16") == 0)
+   if(t.data_type.compare("fp32") == 0)
    {
-       using XDataType = ck_tile::half_t;
-       using YDataType = ck_tile::half_t;
-       using GammaDataType = ck_tile::half_t;
-       using BetaDataType = ck_tile::half_t;
+       using XDataType = float;
+       using YDataType = float;
+       using GammaDataType = float;
+       using BetaDataType = float;
#ifdef SAVE_MEAN_INV_STD
-       using MeanDataType = ck_tile::half_t;
-       using InvStdDataType = ck_tile::half_t;
+       using MeanDataType = float;
+       using InvStdDataType = float;
#else
        using MeanDataType = ck_tile::null_type;
        using InvStdDataType = ck_tile::null_type;
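After this change, passing the "fp32" precision string makes the dispatcher above instantiate the float pipeline instead of the previous ck_tile::half_t one. A hedged usage sketch of the entry point shown in this hunk; the fields of layernorm2d_fwd_traits other than data_type, and of layernorm2d_fwd_args, are not visible in this diff and are assumptions to be filled in as the example's main() does:

// Sketch only: select the fp32 path of the layernorm2d_fwd dispatcher.
layernorm2d_fwd_traits t;
t.data_type = "fp32"; // matches the compare("fp32") branch above

layernorm2d_fwd_args a{};   // pointers, sizes, epsilon, ... as set up by the example
ck_tile::stream_config s{}; // stream/timing configuration (default construction assumed)

float avg_time_ms = layernorm2d_fwd(t, a, s);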
@@ -82,13 +82,13 @@ int main(int argc, char* argv[])
    std::string data_type = arg_parser.get_str("prec");
    int do_validation = arg_parser.get_int("v");
-   using XDataType = ck_tile::half_t;
-   using YDataType = ck_tile::half_t;
-   using GammaDataType = ck_tile::half_t;
-   using BetaDataType = ck_tile::half_t;
+   using XDataType = float;
+   using YDataType = float;
+   using GammaDataType = float;
+   using BetaDataType = float;
#ifdef SAVE_MEAN_INV_STD
-   using MeanDataType = ck_tile::half_t;
-   using InvStdDataType = ck_tile::half_t;
+   using MeanDataType = float;
+   using InvStdDataType = float;
#else
    using MeanDataType = ck_tile::null_type;
    using InvStdDataType = ck_tile::null_type;
...
@@ -59,25 +59,25 @@ template <index_t BlockSize,
          // ,bool TransposeC //disable transposec right now...
          >
struct BlockwiseGemmXdlops_pipeline_v3_b_scale<BlockGemmPipelineScheduler::Intrawave,
                                               BlockSize,
                                               ADataType,
                                               BDataType,
                                               ComputeDataType,
                                               AccDataType,
                                               ATileDesc,
                                               BTileDesc,
                                               AMmaTileDesc,
                                               BMmaTileDesc,
                                               ABlockTransferSrcScalarPerVector,
                                               BBlockTransferSrcScalarPerVector,
                                               MPerBlock,
                                               NPerBlock,
                                               KPerBlock,
                                               MPerXDL,
                                               NPerXDL,
                                               MRepeat,
                                               NRepeat,
                                               KPack>
    : BlockwiseGemmXdlops_pipeline_base<BlockSize,
                                        ADataType,
                                        BDataType,
@@ -276,11 +276,11 @@ struct BlockwiseGemmXdlops_pipeline_v3_b_scale<BlockGemmPipelineScheduler::Intra
              typename BBlockBuffer,
              typename BBlockTransferStep,
              typename CThreadBuffer,
              // typename AScaleGridBuffer,
              // typename AScaleGridDesc,
              // typename AScaleThreadDesc,
              // typename AScaleThreadTransfer,
              // typename AScaleThreadTransferStep,
              typename BScaleGridBuffer,
              typename BScaleGridDesc,
              typename BScaleThreadDesc,
@@ -332,7 +332,6 @@ struct BlockwiseGemmXdlops_pipeline_v3_b_scale<BlockGemmPipelineScheduler::Intra
        // b_scale_thread_desc.GetElementSpaceSize());
        auto b_scale_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ck::half_t>(
            b_scale_thread_desc.GetElementSpaceSize());
        // Global prefetch 1
        a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
@@ -477,7 +476,8 @@ struct BlockwiseGemmXdlops_pipeline_v3_b_scale<BlockGemmPipelineScheduler::Intra
                    make_tuple(I0, I0),
                    b_scale_thread_buf);
-               // a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, a_scale_thread_copy_step);
+               // a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
+               // a_scale_thread_copy_step);
                b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, b_scale_thread_copy_step);
                HotLoopScheduler();
                __builtin_amdgcn_sched_barrier(0);
...
@@ -24,12 +24,12 @@ template <typename ALayout,
          typename DsLayout,
          typename ELayout,
          typename ADataType,
          // typename AScaleType,
          typename BDataType,
          typename BScaleType,
          typename DsDataType,
          typename EDataType,
          // index_t ScaleBlockM,
          index_t ScaleBlockN,
          index_t ScaleBlockK,
          typename AElementwiseOperation,
...
@@ -188,10 +188,7 @@ struct DeviceGemmMultiD_BScale_Xdl_CShuffle_V3
        };
        constexpr index_t minimum_occupancy =
-           (BlkGemmPipeSched == BlockGemmPipelineScheduler::Intrawave &&
-            MPerBlock * NPerBlock / BlockSize > 64)
-               ? 1
-               : 2;
+           (BlkGemmPipeSched == BlockGemmPipelineScheduler::Intrawave) ? 1 : 2;
        if(has_main_k_block_loop)
        {
...
@@ -9,7 +9,7 @@
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_scale_selector.hpp"
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "ck/utility/common_header.hpp"
-// #include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
+#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1_b_scale.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
@@ -34,7 +34,7 @@ template <typename GridwiseGemm,
          TailNumber TailNum = TailNumber::Full>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
#endif
    // __attribute__((amdgpu_waves_per_eu(1, 1)))
    kernel_gemm_xdl_cshuffle_v3(typename GridwiseGemm::Argument karg)
@@ -1274,34 +1274,35 @@ struct GridwiseGemmMultiD_BScale_xdl_cshuffle_v3
        constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
        // A matrix blockwise copy
-       auto a_blockwise_copy = ThreadGroupTensorSliceTransfer_v4r1_b_scale<
-           ThisThreadBlock,
+       auto a_blockwise_copy =
+           ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
                AElementwiseOperation,
                ck::tensor_operation::element_wise::PassThrough,
                InMemoryDataOperationEnum::Set,
                Sequence<AK0Number, MPerBlock, AK1Number>,
                ABlockTransferThreadClusterLengths_AK0_M_AK1,
                ABlockTransferThreadClusterArrangeOrder,
                ADataType,
-               LDSTypeA,
+               ADataType,
                decltype(a_grid_desc_ak0_m_ak1),
                decltype(a_block_desc_ak0_m_ak1),
                ABlockTransferSrcAccessOrder,
                Sequence<0, 1, 2>,
                ABlockTransferSrcVectorDim,
                2,
                ABlockTransferSrcScalarPerVector,
                ABlockTransferDstScalarPerVector_AK1,
                1,
                1,
                AThreadTransferSrcResetCoordinateAfterRun,
                true,
-               BlockwiseGemmPipe::GlobalBufferNum>(a_grid_desc_ak0_m_ak1,
-                                                   make_multi_index(0, m_block_data_idx_on_grid, 0),
-                                                   a_element_op,
-                                                   a_block_desc_ak0_m_ak1,
-                                                   make_multi_index(0, 0, 0),
-                                                   ck::tensor_operation::element_wise::PassThrough{});
+               BlockwiseGemmPipe::GlobalBufferNum>(
+                   a_grid_desc_ak0_m_ak1,
+                   make_multi_index(0, m_block_data_idx_on_grid, 0),
+                   a_element_op,
+                   a_block_desc_ak0_m_ak1,
+                   make_multi_index(0, 0, 0),
+                   ck::tensor_operation::element_wise::PassThrough{});
        // B matrix blockwise copy
        auto b_blockwise_copy = ThreadGroupTensorSliceTransfer_v4r1_b_scale<
@@ -1562,16 +1563,18 @@ struct GridwiseGemmMultiD_BScale_xdl_cshuffle_v3
        // tuple of reference to C/Ds tensor descriptors
        const auto c_ds_desc_refs = concat_tuple_of_reference(
            tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
            generate_tie(
                [&](auto i) -> const auto& // return type should be reference
                { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; },
                Number<NumDTensor>{}));
        // tuple of reference to C/Ds tensor descriptors
        const auto c_ds_buf_refs = concat_tuple_of_reference(
            tie(c_shuffle_block_buf),
            generate_tie(
                [&](auto i) -> const auto& // return type should be reference
                { return ds_grid_buf[i]; },
                Number<NumDTensor>{}));
        // tuple of starting index of C/Ds blockwise copy
        const auto idx_c_ds_block_begin = container_concat(
...
@@ -216,7 +216,8 @@ struct ThreadwiseTensorSliceTransfer_v3r1_b_scale
                .template SetAsType<src_vector_t>(
                    src_data_idx_seq, src_vector_container.template AsType<src_vector_t>()[I0]);
-           constexpr auto move_on_dim = [&]() constexpr {
+           constexpr auto move_on_dim = [&]() constexpr
+           {
                StaticallyIndexedArray<bool, nDim> move_on_dim_;
                static_for<0, nDim, 1>{}([&](auto i) {
@@ -229,7 +230,8 @@ struct ThreadwiseTensorSliceTransfer_v3r1_b_scale
                });
                return move_on_dim_;
-           }();
+           }
+           ();
            // move src coord
            static_for<0, nDim, 1>{}([&](auto i) {
@@ -499,7 +501,8 @@ struct ThreadwiseTensorSliceTransfer_v3r1_b_scale
                    is_dst_valid,
                    dst_vector_container.template AsType<dst_vector_t>()[I0]);
-           constexpr auto move_on_dim = [&]() constexpr {
+           constexpr auto move_on_dim = [&]() constexpr
+           {
                StaticallyIndexedArray<bool, nDim> move_on_dim_;
                static_for<0, nDim, 1>{}([&](auto i) {
@@ -512,7 +515,8 @@ struct ThreadwiseTensorSliceTransfer_v3r1_b_scale
                });
                return move_on_dim_;
-           }();
+           }
+           ();
            // move dst coord
            static_for<0, nDim, 1>{}([&](auto i) {
...