Commit 56de337f authored by Jun Liu's avatar Jun Liu
Browse files

Merge branch 'amd-develop' into amd-master

parents 41b920e2 687d2b7e
add_example_executable(example_batched_gemm_bias_e_permute_xdl_fp16 batched_gemm_bias_e_permute_xdl_fp16.cpp)
if(GPU_TARGETS MATCHES "gfx1100" OR GPU_TARGETS MATCHES "gfx1101" OR GPU_TARGETS MATCHES "gfx1102")
if(GPU_TARGETS MATCHES "gfx11")
add_example_executable(example_batched_gemm_bias_e_permute_wmma_fp16 batched_gemm_bias_e_permute_wmma_fp16.cpp)
endif()
......@@ -43,9 +43,10 @@ using AElementOp = ck::tensor_operation::element_wise::PassThrough;
using BElementOp = ck::tensor_operation::element_wise::PassThrough;
using CDEElementOp = ck::tensor_operation::element_wise::Add;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
static constexpr auto ABSpec = ck::tensor_operation::device::TensorSpecialization::Packed;
static constexpr auto ASpec = ck::tensor_operation::device::TensorSpecialization::Default;
static constexpr auto BSpec = ck::tensor_operation::device::TensorSpecialization::Default;
static constexpr auto DESpec = ck::tensor_operation::device::TensorSpecialization::Default;
using DeviceOpInstanceKKNN =
......@@ -55,43 +56,44 @@ using DeviceOpInstanceKKNN =
NumDimK,
ADataType,
BDataType,
DsDataType,
EDataType,
AccDataType,
CShuffleDataType,
DsDataType,
EDataType,
AElementOp,
BElementOp,
CDEElementOp,
GemmSpec,
ABSpec,
ABSpec,
ASpec,
BSpec,
DESpec,
256,
1,
128,
256,
8,
8,
64,
64,
64,
4,
16,
16,
1,
4,
4,
S<4, 64, 1>,
S<4, 32, 1>,
S<1, 0, 2>,
S<1, 0, 2>,
2,
8,
8,
4,
4,
true,
S<4, 64, 1>,
S<4, 32, 1>,
S<1, 0, 2>,
S<1, 0, 2>,
2,
8,
8,
4,
4,
true,
1,
1,
S<1, 32, 1, 8>,
S<1, 64, 1, 2>,
8>;
using DeviceOpInstance = DeviceOpInstanceKKNN;
......@@ -251,21 +253,6 @@ int main(int argc, char* argv[])
ck::index_t K0 = 2048;
// A[G0, G1, M0, M1, K0]
std::vector<ck::index_t> a_gs_ms_ks_lengths{G0, G1, M0, M1, K0};
std::vector<ck::index_t> a_gs_ms_ks_strides{G1 * M0 * M1 * K0, M0 * M1 * K0, M1 * K0, K0, 1};
// B[G0, G1, N0, N1, K0]
std::vector<ck::index_t> b_gs_ns_ks_lengths{G0, G1, N0, N1, K0};
std::vector<ck::index_t> b_gs_ns_ks_strides{G1 * N0 * N1 * K0, N0 * N1 * K0, N1 * K0, K0, 1};
// D[G0, G1, M0, N0, M1, N1]
std::vector<ck::index_t> d_gs_ms_ns_lengths{G0, G1, M0, M1, N0, N1};
std::vector<ck::index_t> d_gs_ms_ns_strides{G1 * N0 * N1, N0 * N1, 0, 0, N1, 1};
// E[G0, G1, M0, N0, M1, N1]
std::vector<ck::index_t> e_gs_ms_ns_lengths{G0, G1, M0, M1, N0, N1};
std::vector<ck::index_t> e_gs_ms_ns_strides{
G1 * M0 * N0 * M1 * N1, M0 * N0 * M1 * N1, N0 * M1 * N1, N1, M1 * N1, 1};
if(argc == 1)
{
// use default case
......@@ -276,13 +263,43 @@ int main(int argc, char* argv[])
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
}
else if(argc == 11)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
G0 = std::stoi(argv[4]);
G1 = std::stoi(argv[5]);
M0 = std::stoi(argv[6]);
M1 = std::stoi(argv[7]);
N0 = std::stoi(argv[8]);
N1 = std::stoi(argv[9]);
K0 = std::stoi(argv[10]);
}
else
{
printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: time kernel (0=no, 1=yes)\n");
printf("arg4-10: G0, G1, M0, M1, N0, N1, K0\n");
exit(0);
}
// A[G0, G1, M0, M1, K0]
std::vector<ck::index_t> a_gs_ms_ks_lengths{G0, G1, M0, M1, K0};
std::vector<ck::index_t> a_gs_ms_ks_strides{G1 * M0 * M1 * K0, M0 * M1 * K0, M1 * K0, K0, 1};
// B[G0, G1, N0, N1, K0]
std::vector<ck::index_t> b_gs_ns_ks_lengths{G0, G1, N0, N1, K0};
std::vector<ck::index_t> b_gs_ns_ks_strides{G1 * N0 * N1 * K0, N0 * N1 * K0, N1 * K0, K0, 1};
// D[G0, G1, M0, N0, M1, N1]
std::vector<ck::index_t> d_gs_ms_ns_lengths{G0, G1, M0, M1, N0, N1};
std::vector<ck::index_t> d_gs_ms_ns_strides{G1 * N0 * N1, N0 * N1, 0, 0, N1, 1};
// E[G0, G1, M0, N0, M1, N1]
std::vector<ck::index_t> e_gs_ms_ns_lengths{G0, G1, M0, M1, N0, N1};
std::vector<ck::index_t> e_gs_ms_ns_strides{
G1 * M0 * N0 * M1 * N1, M0 * N0 * M1 * N1, N0 * M1 * N1, N1, M1 * N1, 1};
Tensor<ADataType> a_gs_ms_ks(a_gs_ms_ks_lengths, a_gs_ms_ks_strides);
Tensor<BDataType> b_gs_ns_ks(b_gs_ns_ks_lengths, b_gs_ns_ks_strides);
Tensor<DDataType> d_gs_ms_ns(d_gs_ms_ns_lengths, d_gs_ms_ns_strides);
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
template <typename BiasLay, typename ResidualLay>
struct LayoutSetting
......@@ -42,41 +42,42 @@ using DeviceConvFwdInstance =
OutputLayout<NDimSpatial>,
InKernelDataType,
WeiKernelDataType,
ck::Tuple<BiasKernelDataType, ResidualKernelDataType>,
OutKernelDataType,
AccDataType,
CShuffleDataType,
ck::Tuple<BiasKernelDataType, ResidualKernelDataType>,
OutKernelDataType,
InElementOp,
WeiElementOp,
OutElementOp,
ConvSpec, // ConvForwardSpecialization
GemmSpec, // GemmSpecialization
256, // BlockSize
128, // MPerBlock
128, // NPerBlock
4, // K0PerBlock
1, // Prefetch stage
128, // BlockSize
64, // MPerBlock
64, // NPerBlock
64, // KPerBlock
8, // K1
16, // MPerWMMA
16, // NPerWMMA
4, // MRepeat
2, // NRepeat
S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
1, // NRepeat
S<4, 32, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // ABlockTransferSrcAccessOrder
2, // ABlockTransferSrcVectorDim
8, // ABlockTransferSrcScalarPerVector
8, // ABlockTransferDstScalarPerVector_AK1
true, // ABlockLdsExtraM
S<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1
S<4, 32, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1
S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // BBlockTransferSrcAccessOrder
2, // BBlockTransferSrcVectorDim
8, // BBlockTransferSrcScalarPerVector
8, // BBlockTransferDstScalarPerVector_BK1
true, // BBlockLdsExtraN
4,
2,
S<1, 32, 1, 8>,
1,
1,
S<1, 16, 1, 8>,
8>;
template <ck::index_t NDimSpatial>
......@@ -277,9 +278,10 @@ bool run_grouped_conv_fwd_bias_relu_add_example(int argc, char* argv[])
switch(conv_param.num_dim_spatial_)
{
case 1: return run_grouped_conv_fwd_bias_relu_add<1>(config, conv_param);
case 2: return run_grouped_conv_fwd_bias_relu_add<2>(config, conv_param);
case 3: return run_grouped_conv_fwd_bias_relu_add<3>(config, conv_param);
// case 1: return run_grouped_conv_fwd_bias_relu_add<1>(config, conv_param);
case 2:
return run_grouped_conv_fwd_bias_relu_add<2>(config, conv_param);
// case 3: return run_grouped_conv_fwd_bias_relu_add<3>(config, conv_param);
}
return false;
......
if(GPU_TARGETS MATCHES "gfx11")
add_example_executable(example_batched_gemm_lower_triangle_scale_softmax_gemm_permute_wmma_fp16 batched_gemm_lower_triangle_scale_softmax_gemm_permute_wmma_fp16.cpp)
add_example_executable(example_batched_gemm_scale_softmax_gemm_permute_wmma_fp16 batched_gemm_scale_softmax_gemm_permute_wmma_fp16.cpp)
add_example_executable(example_self_attention_forward_wmma_fp16 self_attention_forward_wmma_fp16.cpp)
add_example_executable(example_cross_attention_forward_wmma_fp16 cross_attention_forward_wmma_fp16.cpp)
add_example_executable(example_multi_query_attention_forward_wmma_fp16 multi_query_attention_forward_wmma_fp16.cpp)
add_example_executable(example_grouped_query_attention_forward_wmma_fp16 grouped_query_attention_forward_wmma_fp16.cpp)
endif()
add_custom_target(example_gemm_scale_softmax_gemm)
add_example_executable(example_batched_gemm_scale_softmax_gemm_xdl_fp16 batched_gemm_scale_softmax_gemm_xdl_fp16.cpp)
......@@ -20,4 +29,3 @@ add_example_dependencies(example_gemm_scale_softmax_gemm example_batched_gemm_sc
add_example_executable(example_batched_gemm_scale_softmax_gemm_permute_xdl_bf16 batched_gemm_scale_softmax_gemm_permute_xdl_bf16.cpp)
add_example_dependencies(example_gemm_scale_softmax_gemm example_batched_gemm_scale_softmax_gemm_permute_xdl_bf16)
if(GPU_TARGETS MATCHES "gfx11")
add_custom_target(example_fpAintB_gemm_wmma)
add_example_executable(example_fp16int8_gemm_wmma fp16int8_gemm_wmma.cpp)
add_dependencies(example_fpAintB_gemm_wmma example_fp16int8_gemm_wmma)
endif()
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
#include <iostream>
#include <initializer_list>
#include <numeric>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/utility/data_type.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/fill.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/literals.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_fpAintB_gemm.hpp"
struct ProblemSize final
{
ck::index_t M = 3840;
ck::index_t N = 4096;
ck::index_t K = 4096;
ck::index_t StrideA = 4096;
ck::index_t StrideB = 4096;
ck::index_t StrideC = 4096;
};
struct ExecutionConfig final
{
bool do_verification = true;
int init_method = 1;
bool time_kernel = false;
};
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
template <typename IntType>
struct UnsignedWeightPreprocessor
{
};
template <>
struct UnsignedWeightPreprocessor<int8_t>
{
using UnsignedWeight = Tensor<uint8_t>;
using SignedWeight = Tensor<int8_t>;
static UnsignedWeight convert(SignedWeight const& Input)
{
UnsignedWeight Output = Input.template CopyAsType<uint8_t>();
auto f_kn = [&](auto k, auto n) {
const uint8_t adder = 128;
int8_t v_signed_weight;
uint8_t v_unsigned_weight;
ck::tensor_operation::element_wise::PassThrough{}(v_signed_weight, Input(k, n));
v_unsigned_weight = ck::type_convert<uint8_t>(v_signed_weight) + adder;
Output(k, n) = v_unsigned_weight;
};
make_ParallelTensorFunctor(f_kn, Input.mDesc.GetLengths()[0], Input.mDesc.GetLengths()[1])(
std::thread::hardware_concurrency());
return Output;
}
UnsignedWeight operator()(SignedWeight const& Input) { return convert(Input); }
};
inline bool
parse_cmd_args(int argc, char* argv[], ProblemSize& problem_size, ExecutionConfig& config)
{
if(argc == 1)
{
// use default case
}
else if(argc == 4)
{
config.do_verification = std::stoi(argv[1]);
config.init_method = std::stoi(argv[2]);
config.time_kernel = std::stoi(argv[3]);
}
else if(argc == 10)
{
config.do_verification = std::stoi(argv[1]);
config.init_method = std::stoi(argv[2]);
config.time_kernel = std::stoi(argv[3]);
problem_size.M = std::stoi(argv[4]);
problem_size.N = std::stoi(argv[5]);
problem_size.K = std::stoi(argv[6]);
problem_size.StrideA = std::stoi(argv[7]);
problem_size.StrideB = std::stoi(argv[8]);
problem_size.StrideC = std::stoi(argv[9]);
}
else
{
std::cerr << "arg1: verification (0=no, 1=yes)" << std::endl
<< "arg2: initialization (0=no init, 1=integer value, 2=decimal value)"
<< std::endl
<< "arg3: time kernel (0=no, 1=yes)" << std::endl
<< "arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideC" << std::endl;
return false;
}
return true;
}
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment