Commit d676da85 authored by aska-0096
Browse files

Clang format, Add gfx1101, gfx1102 support of FMHA example

parent 6e2c6159
...@@ -49,32 +49,32 @@ using DeviceConvFwdInstance = ...@@ -49,32 +49,32 @@ using DeviceConvFwdInstance =
InElementOp, InElementOp,
WeiElementOp, WeiElementOp,
OutElementOp, OutElementOp,
ConvSpec, // ConvForwardSpecialization ConvSpec, // ConvForwardSpecialization
GemmSpec, // GemmSpecialization GemmSpec, // GemmSpecialization
1, // Prefetch stage 1, // Prefetch stage
256, // BlockSize 256, // BlockSize
128, // MPerBlock 128, // MPerBlock
128, // NPerBlock 128, // NPerBlock
32, // KPerBlock 32, // KPerBlock
8, // K1 8, // K1
16, // MPerWMMA 16, // MPerWMMA
16, // NPerWMMA 16, // NPerWMMA
4, // MRepeat 4, // MRepeat
2, // NRepeat 2, // NRepeat
S<4, 8, 8>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 S<4, 8, 8>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // ABlockTransferSrcAccessOrder S<1, 0, 2>, // ABlockTransferSrcAccessOrder
2, // ABlockTransferSrcVectorDim 2, // ABlockTransferSrcVectorDim
1, // ABlockTransferSrcScalarPerVector 1, // ABlockTransferSrcScalarPerVector
1, // ABlockTransferDstScalarPerVector_AK1 1, // ABlockTransferDstScalarPerVector_AK1
true, // ABlockLdsExtraM true, // ABlockLdsExtraM
S<4, 8, 8>, // BBlockTransferThreadClusterLengths_BK0_N_BK1 S<4, 8, 8>, // BBlockTransferThreadClusterLengths_BK0_N_BK1
S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // BBlockTransferSrcAccessOrder S<1, 0, 2>, // BBlockTransferSrcAccessOrder
2, // BBlockTransferSrcVectorDim 2, // BBlockTransferSrcVectorDim
1, // BBlockTransferSrcScalarPerVector 1, // BBlockTransferSrcScalarPerVector
1, // BBlockTransferDstScalarPerVector_BK1 1, // BBlockTransferDstScalarPerVector_BK1
true, // BBlockLdsExtraN true, // BBlockLdsExtraN
4, 4,
2, 2,
S<1, 32, 1, 8>, S<1, 32, 1, 8>,
......
...@@ -5,7 +5,7 @@ add_example_executable(example_batched_gemm_scale_softmax_gemm_permute_xdl_bf16 ...@@ -5,7 +5,7 @@ add_example_executable(example_batched_gemm_scale_softmax_gemm_permute_xdl_bf16
add_example_executable(example_grouped_gemm_scale_softmax_gemm_permute_xdl_fp16 grouped_gemm_scale_softmax_gemm_permute_xdl_fp16.cpp) add_example_executable(example_grouped_gemm_scale_softmax_gemm_permute_xdl_fp16 grouped_gemm_scale_softmax_gemm_permute_xdl_fp16.cpp)
add_example_executable(example_batched_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16 batched_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16.cpp) add_example_executable(example_batched_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16 batched_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16.cpp)
add_example_executable(example_grouped_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16 grouped_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16.cpp) add_example_executable(example_grouped_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16 grouped_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16.cpp)
if(GPU_TARGETS MATCHES "gfx1100") if(GPU_TARGETS MATCHES "gfx1100" OR GPU_TARGETS MATCHES "gfx1101" OR GPU_TARGETS MATCHES "gfx1102")
add_example_executable(example_batched_gemm_lower_triangle_scale_softmax_gemm_permute_wmma_fp16 batched_gemm_lower_triangle_scale_softmax_gemm_permute_wmma_fp16.cpp) add_example_executable(example_batched_gemm_lower_triangle_scale_softmax_gemm_permute_wmma_fp16 batched_gemm_lower_triangle_scale_softmax_gemm_permute_wmma_fp16.cpp)
add_example_executable(example_batched_gemm_scale_softmax_gemm_permute_wmma_fp16 batched_gemm_scale_softmax_gemm_permute_wmma_fp16.cpp) add_example_executable(example_batched_gemm_scale_softmax_gemm_permute_wmma_fp16 batched_gemm_scale_softmax_gemm_permute_wmma_fp16.cpp)
endif() endif()
...@@ -19,7 +19,7 @@ add_dependencies(example_gemm_scale_softmax_gemm example_grouped_gemm_scale_soft ...@@ -19,7 +19,7 @@ add_dependencies(example_gemm_scale_softmax_gemm example_grouped_gemm_scale_soft
add_dependencies(example_gemm_scale_softmax_gemm example_batched_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16) add_dependencies(example_gemm_scale_softmax_gemm example_batched_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16)
add_dependencies(example_gemm_scale_softmax_gemm example_grouped_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16) add_dependencies(example_gemm_scale_softmax_gemm example_grouped_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16)
if(GPU_TARGETS MATCHES "gfx1100") if(GPU_TARGETS MATCHES "gfx1100" OR GPU_TARGETS MATCHES "gfx1101" OR GPU_TARGETS MATCHES "gfx1102")
add_custom_target(example_gemm_scale_softmax_gemm_wmma) add_custom_target(example_gemm_scale_softmax_gemm_wmma)
add_dependencies(example_gemm_scale_softmax_gemm_wmma example_batched_gemm_lower_triangle_scale_softmax_gemm_permute_wmma_fp16) add_dependencies(example_gemm_scale_softmax_gemm_wmma example_batched_gemm_lower_triangle_scale_softmax_gemm_permute_wmma_fp16)
add_dependencies(example_gemm_scale_softmax_gemm_wmma example_batched_gemm_scale_softmax_gemm_permute_wmma_fp16) add_dependencies(example_gemm_scale_softmax_gemm_wmma example_batched_gemm_scale_softmax_gemm_permute_wmma_fp16)
......
...@@ -418,7 +418,7 @@ struct BlockwiseGemmWMMA ...@@ -418,7 +418,7 @@ struct BlockwiseGemmWMMA
} }
protected: protected:
static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor( static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor(
make_tuple(Number<WmmaK / A_K1>{}, Number<MRepeat>{}, I1, I1, Number<A_K1>{}), make_tuple(Number<WmmaK / A_K1>{}, Number<MRepeat>{}, I1, I1, Number<A_K1>{}),
make_tuple(Number<A_K1>{}, Number<WmmaK>{}, Number<A_K1>{}, Number<A_K1>{}, Number<1>{})); make_tuple(Number<A_K1>{}, Number<WmmaK>{}, Number<A_K1>{}, Number<A_K1>{}, Number<1>{}));
......
...@@ -99,7 +99,7 @@ template <index_t NDimSpatial, ...@@ -99,7 +99,7 @@ template <index_t NDimSpatial,
typename DsLayout, typename DsLayout,
typename ELayout, typename ELayout,
typename ADataType, typename ADataType,
typename BDataType, typename BDataType,
typename AccDataType, typename AccDataType,
typename CShuffleDataType, typename CShuffleDataType,
typename DsDataType, typename DsDataType,
...@@ -180,7 +180,6 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle ...@@ -180,7 +180,6 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
static constexpr auto AEnableLds = AEnableLds_auto || AEnableLds_manu; static constexpr auto AEnableLds = AEnableLds_auto || AEnableLds_manu;
static constexpr auto BEnableLds = BEnableLds_auto || BEnableLds_manu; static constexpr auto BEnableLds = BEnableLds_auto || BEnableLds_manu;
static constexpr auto conv_to_gemm_transformer = static constexpr auto conv_to_gemm_transformer =
TransformConvFwdToGemm<NDimSpatial, ConvForwardSpecialization>{}; TransformConvFwdToGemm<NDimSpatial, ConvForwardSpecialization>{};
......
...@@ -712,8 +712,8 @@ struct GridwiseGemmMultipleD_Wmma ...@@ -712,8 +712,8 @@ struct GridwiseGemmMultipleD_Wmma
const auto M = e_grid_desc_m_n.GetLength(I0); const auto M = e_grid_desc_m_n.GetLength(I0);
const auto N = e_grid_desc_m_n.GetLength(I1); const auto N = e_grid_desc_m_n.GetLength(I1);
const auto MBlock = M / MPerBlock; const auto MBlock = M / MPerBlock;
const auto NBlock = N / NPerBlock; const auto NBlock = N / NPerBlock;
const auto e_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor( const auto e_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor(
e_grid_desc_m_n, e_grid_desc_m_n,
make_tuple(make_unmerge_transform(make_tuple(MBlock, Number<MPerBlock>{})), make_tuple(make_unmerge_transform(make_tuple(MBlock, Number<MPerBlock>{})),
......
# find . -name deps -prune -o -name build -prune -o -iname '*.h' -o -iname '*.hpp' -o -iname '*.cpp' -o -iname '*.h.in' -o -iname '*.hpp.in' -o -iname '*.cpp.in' -o -iname '*.cl' -o -iname '*.cuh' -o -iname '*.cu' -o -iname '*.inc' | xargs -n 1 -P 16 -I{} -t sh -c 'clang-format-10 -i -style=file {}' find . -name deps -prune -o -name build -prune -o -iname '*.h' -o -iname '*.hpp' -o -iname '*.cpp' -o -iname '*.h.in' -o -iname '*.hpp.in' -o -iname '*.cpp.in' -o -iname '*.cl' -o -iname '*.cuh' -o -iname '*.cu' -o -iname '*.inc' | xargs -n 1 -P 16 -I{} -t sh -c 'clang-format-10 -i -style=file {}'
git status --porcelain | awk '$1 != "D" && (match($2, "\\.cpp|hpp|inc")) {print $2}' | xargs -n 1 -P 16 -I{} -t sh -c 'clang-format-10 -i -style=file {}' # git status --porcelain | awk '$1 != "D" && (match($2, "\\.cpp|hpp|inc")) {print $2}' | xargs -n 1 -P 16 -I{} -t sh -c 'clang-format-10 -i -style=file {}'
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment