Commit 4947639c authored by Jun Liu's avatar Jun Liu
Browse files

Merge branch 'amd-develop' into amd-master

parents 17cf8179 d39c3f5d
resources:
repositories:
- repository: pipelines_repo
type: github
endpoint: ROCm
name: ROCm/ROCm
variables:
- group: common
- template: /.azuredevops/variables-global.yml@pipelines_repo
trigger:
batch: true
branches:
include:
- develop
paths:
exclude:
- .github
- docs
- '.*.y*ml'
- '*.md'
- Jenkinsfile
- LICENSE
pr:
autoCancel: true
branches:
include:
- develop
paths:
exclude:
- .github
- docs
- '.*.y*ml'
- '*.md'
- Jenkinsfile
- LICENSE
drafts: false
jobs:
- template: ${{ variables.CI_COMPONENT_PATH }}/composable_kernel.yml@pipelines_repo
...@@ -652,8 +652,8 @@ def process_results(Map conf=[:]){ ...@@ -652,8 +652,8 @@ def process_results(Map conf=[:]){
} }
//launch develop branch daily at 23:00 UT in FULL_QA mode and at 19:00 UT with latest staging compiler version //launch develop branch daily at 23:00 UT in FULL_QA mode and at 19:00 UT with latest staging compiler version
CRON_SETTINGS = BRANCH_NAME == "develop" ? '''0 23 * * * % RUN_FULL_QA=true;ROCMVERSION=6.1;COMPILER_VERSION= CRON_SETTINGS = BRANCH_NAME == "develop" ? '''0 23 * * * % RUN_FULL_QA=true;ROCMVERSION=6.1;
0 21 * * * % ROCMVERSION=6.1;COMPILER_VERSION=;COMPILER_COMMIT= 0 21 * * * % ROCMVERSION=6.1;hipTensor_test=true
0 19 * * * % BUILD_DOCKER=true;DL_KERNELS=true;COMPILER_VERSION=amd-staging;COMPILER_COMMIT=;USE_SCCACHE=false 0 19 * * * % BUILD_DOCKER=true;DL_KERNELS=true;COMPILER_VERSION=amd-staging;COMPILER_COMMIT=;USE_SCCACHE=false
0 17 * * * % BUILD_DOCKER=true;DL_KERNELS=true;COMPILER_VERSION=amd-mainline-open;COMPILER_COMMIT=;USE_SCCACHE=false 0 17 * * * % BUILD_DOCKER=true;DL_KERNELS=true;COMPILER_VERSION=amd-mainline-open;COMPILER_COMMIT=;USE_SCCACHE=false
0 15 * * * % BUILD_INSTANCES_ONLY=true;RUN_CODEGEN_TESTS=false;RUN_PERFORMANCE_TESTS=false;USE_SCCACHE=false''' : "" 0 15 * * * % BUILD_INSTANCES_ONLY=true;RUN_CODEGEN_TESTS=false;RUN_PERFORMANCE_TESTS=false;USE_SCCACHE=false''' : ""
...@@ -701,8 +701,8 @@ pipeline { ...@@ -701,8 +701,8 @@ pipeline {
description: "Select whether to build DL kernels (default: OFF)") description: "Select whether to build DL kernels (default: OFF)")
booleanParam( booleanParam(
name: "hipTensor_test", name: "hipTensor_test",
defaultValue: true, defaultValue: false,
description: "Use the CK build to verify hipTensor build and tests (default: ON)") description: "Use the CK build to verify hipTensor build and tests (default: OFF)")
string( string(
name: 'hipTensor_branch', name: 'hipTensor_branch',
defaultValue: 'mainline', defaultValue: 'mainline',
...@@ -911,9 +911,8 @@ pipeline { ...@@ -911,9 +911,8 @@ pipeline {
execute_args = """ cmake -D CMAKE_PREFIX_PATH=/opt/rocm \ execute_args = """ cmake -D CMAKE_PREFIX_PATH=/opt/rocm \
-D CMAKE_CXX_COMPILER="${build_compiler()}" \ -D CMAKE_CXX_COMPILER="${build_compiler()}" \
-D CMAKE_BUILD_TYPE=Release \ -D CMAKE_BUILD_TYPE=Release \
-D GPU_TARGETS="gfx90a;gfx1030;gfx1101" \
-D INSTANCES_ONLY=ON \ -D INSTANCES_ONLY=ON \
-DCMAKE_CXX_FLAGS=" -O3 " .. && make -j32 """ -DCMAKE_CXX_FLAGS=" -O3 " .. && make -j64 """
} }
steps{ steps{
buildHipClangJobAndReboot(setup_cmd: "", build_cmd: "", no_reboot:true, build_type: 'Release', execute_cmd: execute_args) buildHipClangJobAndReboot(setup_cmd: "", build_cmd: "", no_reboot:true, build_type: 'Release', execute_cmd: execute_args)
......
...@@ -35,6 +35,10 @@ target_link_libraries(client_grouped_convnd_fwd_scaleadd_ab_int8 PRIVATE composa ...@@ -35,6 +35,10 @@ target_link_libraries(client_grouped_convnd_fwd_scaleadd_ab_int8 PRIVATE composa
add_executable(client_grouped_convnd_fwd_bilinear_residual_fp16 add_executable(client_grouped_convnd_fwd_bilinear_residual_fp16
grouped_convnd_fwd_bilinear/grouped_conv_fwd_bilinear_residual_fp16.cpp) grouped_convnd_fwd_bilinear/grouped_conv_fwd_bilinear_residual_fp16.cpp)
target_link_libraries(client_grouped_convnd_fwd_bilinear_residual_fp16 PRIVATE composable_kernel::device_conv_operations) target_link_libraries(client_grouped_convnd_fwd_bilinear_residual_fp16 PRIVATE composable_kernel::device_conv_operations)
# Fwd convscale
add_executable(client_conv3d_fwd_convscale_fp8
grouped_convnd_fwd_convscale/conv3d_fwd_convscale_fp8.cpp)
target_link_libraries(client_conv3d_fwd_convscale_fp8 PRIVATE composable_kernel::device_conv_operations)
# Bwd data bilinear # Bwd data bilinear
add_executable(client_grouped_convnd_bwd_data_bilinear_residual_fp16 add_executable(client_grouped_convnd_bwd_data_bilinear_residual_fp16
grouped_convnd_bwd_data_bilinear/grouped_conv_bwd_data_bilinear_residual_fp16.cpp) grouped_convnd_bwd_data_bilinear/grouped_conv_bwd_data_bilinear_residual_fp16.cpp)
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include <iomanip>
#include <iostream>
#include <iterator>
#include <numeric>
#include <string>
#include <vector>
#include "ck/ck.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_convscale.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using ConvScale = ck::tensor_operation::element_wise::ConvScale;
struct SimpleDeviceMem
{
SimpleDeviceMem() = delete;
SimpleDeviceMem(std::size_t mem_size) : p_mem_{}
{
(void)hipMalloc(static_cast<void**>(&p_mem_), mem_size);
}
void* GetDeviceBuffer() { return p_mem_; }
~SimpleDeviceMem() { (void)hipFree(p_mem_); }
void* p_mem_;
};
template <ck::index_t NumDimSpatial, ck::index_t NumNonSpatialDim = 3>
std::size_t
GetFlops(const std::array<ck::index_t, NumDimSpatial + NumNonSpatialDim>& output_lengths,
const std::array<ck::index_t, NumDimSpatial + NumNonSpatialDim>& weights_lengths,
const std::size_t& ds_size)
{
// G * N * C * <output spatial lengths product> * (2 * K * <filter spatial lengths product> +
// <number of scale factors>)
ck::index_t G = weights_lengths[0];
ck::index_t N = output_lengths[1];
ck::index_t K = weights_lengths[1];
ck::index_t C = weights_lengths[2];
return G * N * C *
std::accumulate(std::next(std::begin(output_lengths), NumNonSpatialDim),
std::end(output_lengths),
static_cast<std::size_t>(1),
std::multiplies<>()) *
(static_cast<std::size_t>(2) * K *
std::accumulate(std::next(std::begin(weights_lengths), NumNonSpatialDim),
std::end(weights_lengths),
static_cast<std::size_t>(1),
std::multiplies<>()) +
ds_size);
}
template <typename InDataType, ck::index_t NumDimSpatial, ck::index_t NumNonSpatialDim = 3>
std::size_t
GetInputByte(const std::array<ck::index_t, NumDimSpatial + NumNonSpatialDim>& input_lengths)
{
// sizeof(InDataType) * (G * N * C * <input spatial lengths product>) +
return sizeof(InDataType) * std::accumulate(std::begin(input_lengths),
std::end(input_lengths),
static_cast<std::size_t>(1),
std::multiplies<>());
}
template <typename WeiDataType, ck::index_t NumDimSpatial, ck::index_t NumNonSpatialDim = 3>
std::size_t
GetWeightByte(const std::array<ck::index_t, NumDimSpatial + NumNonSpatialDim>& weights_lengths)
{
// sizeof(WeiDataType) * (G * K * C * <filter spatial lengths product>) +
return sizeof(WeiDataType) * std::accumulate(std::begin(weights_lengths),
std::end(weights_lengths),
static_cast<std::size_t>(1),
std::multiplies<>());
}
template <typename OutDataType, ck::index_t NumDimSpatial, ck::index_t NumNonSpatialDim = 3>
std::size_t
GetOutputByte(const std::array<ck::index_t, NumDimSpatial + NumNonSpatialDim>& output_lengths)
{
// sizeof(OutDataType) * (G * N * K * <output spatial lengths product>);
return sizeof(OutDataType) * std::accumulate(std::begin(output_lengths),
std::end(output_lengths),
static_cast<std::size_t>(1),
std::multiplies<std::size_t>());
}
template <ck::index_t NumDimSpatial,
typename InDataType,
typename WeiDataType,
typename OutDataType,
typename InLayout,
typename WeiLayout,
typename OutLayout,
ck::index_t NumNonSpatialDim = 3,
typename AComputeType = InDataType,
typename BComputeType = AComputeType>
bool run_grouped_conv_fwd_convscale(
std::array<ck::index_t, NumDimSpatial + NumNonSpatialDim> in_lengths,
std::array<ck::index_t, NumDimSpatial + NumNonSpatialDim> wei_lengths,
std::array<ck::index_t, NumDimSpatial + NumNonSpatialDim> out_lengths)
{
std::size_t in_mem_size = GetInputByte<InDataType, NumDimSpatial>(in_lengths);
std::size_t wei_mem_size = GetWeightByte<WeiDataType, NumDimSpatial>(wei_lengths);
std::size_t out_mem_size = GetOutputByte<OutDataType, NumDimSpatial>(out_lengths);
SimpleDeviceMem in(in_mem_size);
SimpleDeviceMem wei(wei_mem_size);
SimpleDeviceMem out(out_mem_size);
float scale_in;
float scale_wei;
float scale_out;
std::array<ck::index_t, NumDimSpatial + NumNonSpatialDim> in_strides;
std::array<ck::index_t, NumDimSpatial + NumNonSpatialDim> wei_strides;
std::array<ck::index_t, NumDimSpatial + NumNonSpatialDim> out_strides;
in_strides.fill(0);
wei_strides.fill(0);
out_strides.fill(0);
in_strides.back() = 1;
wei_strides.back() = 1;
out_strides.back() = 1;
std::partial_sum(rbegin(in_lengths),
std::prev(rend(in_lengths)),
std::next(rbegin(in_strides)),
std::multiplies<>{});
std::partial_sum(rbegin(wei_lengths),
std::prev(rend(wei_lengths)),
std::next(rbegin(wei_strides)),
std::multiplies<>{});
std::partial_sum(rbegin(out_lengths),
std::prev(rend(out_lengths)),
std::next(rbegin(out_strides)),
std::multiplies<>{});
// transpose NDHWGC/KZYXGC/NDHWGK to GNDHWC/GKZYXC/GNDHWK to GNCDHW/GKCZYX/GNKDHW
std::rotate(std::next(rbegin(in_lengths)), std::next(rbegin(in_lengths), 2), rend(in_lengths));
std::rotate(rbegin(in_lengths),
std::next(rbegin(in_lengths)),
std::next(rbegin(in_lengths), NumDimSpatial + 1));
std::rotate(std::next(rbegin(in_strides)), std::next(rbegin(in_strides), 2), rend(in_strides));
std::rotate(rbegin(in_strides),
std::next(rbegin(in_strides)),
std::next(rbegin(in_strides), NumDimSpatial + 1));
std::rotate(rbegin(wei_lengths),
std::next(rbegin(wei_lengths)),
std::next(rbegin(wei_lengths), NumDimSpatial + 1));
std::rotate(rbegin(wei_strides),
std::next(rbegin(wei_strides)),
std::next(rbegin(wei_strides), NumDimSpatial + 1));
std::rotate(
std::next(rbegin(out_lengths)), std::next(rbegin(out_lengths), 2), rend(out_lengths));
std::rotate(rbegin(out_lengths),
std::next(rbegin(out_lengths)),
std::next(rbegin(out_lengths), NumDimSpatial + 1));
std::rotate(
std::next(rbegin(out_strides)), std::next(rbegin(out_strides), 2), rend(out_strides));
std::rotate(rbegin(out_strides),
std::next(rbegin(out_strides)),
std::next(rbegin(out_strides), NumDimSpatial + 1));
std::array<ck::index_t, NumDimSpatial> conv_filter_strides;
std::array<ck::index_t, NumDimSpatial> conv_filter_dilations;
std::array<ck::index_t, NumDimSpatial> input_left_pads;
std::array<ck::index_t, NumDimSpatial> input_right_pads;
conv_filter_strides.fill(1);
conv_filter_dilations.fill(1);
input_left_pads.fill(1);
input_right_pads.fill(1);
std::size_t ds_size = 3; // 3 element-wise scale multipliers
std::size_t flop = GetFlops<NumDimSpatial>(out_lengths, wei_lengths, ds_size);
std::size_t num_bytes =
in_mem_size + wei_mem_size + sizeof(float) + sizeof(float) + sizeof(float) + out_mem_size;
using DeviceOp = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD<NumDimSpatial,
InLayout,
WeiLayout,
ck::Tuple<>,
OutLayout,
InDataType,
WeiDataType,
ck::Tuple<>,
OutDataType,
PassThrough,
PassThrough,
ConvScale,
AComputeType,
BComputeType>;
// get device op instances
const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
DeviceOp>::GetInstances();
std::cout << "found " << op_ptrs.size() << " instances" << std::endl;
std::string best_op_name;
int best_op_id = -1;
float best_avg_time = std::numeric_limits<float>::max();
float best_gb_per_sec = 0;
float best_tflops = 0;
// profile device operation instances
std::cout << "Run all instances and do timing" << std::endl;
for(int i = 0; i < op_ptrs.size(); ++i)
{
auto& op_ptr = op_ptrs[i];
auto argument_ptr = op_ptr->MakeArgumentPointer(
in.GetDeviceBuffer(),
wei.GetDeviceBuffer(),
std::array<const void*, 0>{},
out.GetDeviceBuffer(),
in_lengths,
in_strides,
wei_lengths,
wei_strides,
std::array<std::array<ck::index_t, NumDimSpatial + NumNonSpatialDim>, 0>{},
std::array<std::array<ck::index_t, NumDimSpatial + NumNonSpatialDim>, 0>{},
out_lengths,
out_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads,
PassThrough{},
PassThrough{},
ConvScale{scale_in, scale_wei, scale_out});
auto invoker_ptr = op_ptr->MakeInvokerPointer();
std::string op_name = op_ptr->GetTypeString();
if(op_ptr->IsSupportedArgument(argument_ptr.get()))
{
float avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true});
float tflops = static_cast<float>(flop) / 1.E9 / avg_time;
float gb_per_sec = num_bytes / 1.E6 / avg_time;
std::cout << "Perf: " << std::setw(10) << avg_time << " ms, " << tflops << " TFlops, "
<< gb_per_sec << " GB/s, " << op_name << std::endl;
if(tflops > best_tflops)
{
best_op_id = i;
best_op_name = op_name;
best_avg_time = avg_time;
best_gb_per_sec = gb_per_sec;
best_tflops = tflops;
}
}
else
{
std::cerr << op_name << " does not support this problem" << std::endl;
}
}
if(best_op_id < 0)
{
std::cerr << "no suitable instance" << std::endl;
return false;
}
std::cout << "Best Perf: " << std::setw(10) << best_avg_time << " ms, " << best_tflops
<< " TFlops, " << best_gb_per_sec << " GB/s, " << best_op_name << std::endl;
// run the best intance
{
auto& op_ptr = op_ptrs[best_op_id];
std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString()
<< std::endl;
auto argument_ptr = op_ptr->MakeArgumentPointer(
in.GetDeviceBuffer(),
wei.GetDeviceBuffer(),
std::array<const void*, 0>{},
out.GetDeviceBuffer(),
in_lengths,
in_strides,
wei_lengths,
wei_strides,
std::array<std::array<ck::index_t, NumDimSpatial + NumNonSpatialDim>, 0>{},
std::array<std::array<ck::index_t, NumDimSpatial + NumNonSpatialDim>, 0>{},
out_lengths,
out_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads,
PassThrough{},
PassThrough{},
ConvScale{scale_in, scale_wei, scale_out});
auto invoker_ptr = op_ptr->MakeInvokerPointer();
if(op_ptr->IsSupportedArgument(argument_ptr.get()))
{
invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false});
}
std::cout << "Done" << std::endl;
}
return true;
}
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "common.hpp"
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
using InDataType = ck::f8_t;
using WeiDataType = ck::f8_t;
using CShuffleDataType = float;
using OutDataType = ck::f8_t;
using AComputeDataType = ck::f8_t;
using BComputeDataType = ck::f8_t;
using InLayout = ck::tensor_layout::convolution::NDHWGC;
using WeiLayout = ck::tensor_layout::convolution::GKZYXC;
using OutLayout = ck::tensor_layout::convolution::NDHWGK;
static constexpr ck::index_t NumDimSpatial = 3;
static constexpr ck::index_t G = 1;
static constexpr ck::index_t N = 64;
static constexpr ck::index_t K = 128;
static constexpr ck::index_t C = 64;
static constexpr ck::index_t Z = 3;
static constexpr ck::index_t Y = 3;
static constexpr ck::index_t X = 3;
static constexpr ck::index_t Di = 28;
static constexpr ck::index_t Hi = 28;
static constexpr ck::index_t Wi = 3;
static constexpr ck::index_t Do = 28;
static constexpr ck::index_t Ho = 28;
static constexpr ck::index_t Wo = 3;
int main()
{
return run_grouped_conv_fwd_convscale<NumDimSpatial,
InDataType,
WeiDataType,
OutDataType,
InLayout,
WeiLayout,
OutLayout,
3,
AComputeDataType,
BComputeDataType>(
{N, Di, Hi, Wi, G, C}, {G, K, Z, Y, X, C}, {N, Do, Ho, Wo, G, K})
? EXIT_SUCCESS
: EXIT_FAILURE;
}
rocm-docs-core==1.1.2 rocm-docs-core==1.3.0
sphinxcontrib-bibtex==2.6.2 sphinxcontrib-bibtex==2.6.2
...@@ -103,7 +103,7 @@ requests==2.31.0 ...@@ -103,7 +103,7 @@ requests==2.31.0
# via # via
# pygithub # pygithub
# sphinx # sphinx
rocm-docs-core==1.1.2 rocm-docs-core==1.3.0
# via -r requirements.in # via -r requirements.in
six==1.16.0 six==1.16.0
# via # via
......
add_subdirectory(binary) add_subdirectory(binary)
add_subdirectory(convscale)
add_subdirectory(multi_AB) add_subdirectory(multi_AB)
add_subdirectory(unary) add_subdirectory(unary)
......
list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
set(target 0)
foreach(gpu IN LISTS GPU_TARGETS)
if(gpu IN_LIST gpu_list AND target EQUAL 0)
add_custom_target(example_convnd_activ_xdl_convscale)
add_example_executable(example_convnd_fwd_xdl_convscale_fp8 convnd_fwd_xdl_convscale_fp8.cpp)
add_example_dependencies(example_convnd_activ_xdl_convscale example_convnd_fwd_xdl_convscale_fp8)
set(target 1)
endif()
endforeach()
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include <iostream>
#include <numeric>
#include <type_traits>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/algorithm.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/convolution_parameter.hpp"
#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp"
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using ConvScale = ck::tensor_operation::element_wise::ConvScale;
void print_helper_msg()
{
std::cout << "arg1: verification (0=no, 1=yes)\n"
<< "arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"
<< "arg3: time kernel (0=no, 1=yes)\n"
<< ck::utils::conv::get_conv_param_parser_helper_msg() << std::endl;
}
template <typename DataType>
inline __host__ __device__ constexpr double get_rtol()
{
if constexpr(std::is_same_v<DataType, float>)
{
return 1e-3;
}
else if constexpr(std::is_same_v<DataType, double>)
{
return 1e-6;
}
else if constexpr(std::is_same_v<DataType, ck::half_t>)
{
return 1e-3;
}
else if constexpr(std::is_same_v<DataType, ck::bhalf_t>)
{
return 5e-2;
}
else if constexpr(std::is_same_v<DataType, int32_t>)
{
return 1e-1;
}
else if constexpr(std::is_same_v<DataType, int8_t>)
{
return 1e-1;
}
else if constexpr(std::is_same_v<DataType, ck::f8_t>)
{
return 1e-1; // 240 and 224 are acceptable
}
else if constexpr(std::is_same_v<DataType, ck::bf8_t>)
{
return 1.5e-1; // 57344 and 49152 are acceptable
}
else
{
return 1e-3;
}
}
template <typename DataType>
inline __host__ __device__ constexpr double get_atol()
{
if constexpr(std::is_same_v<DataType, float>)
{
return 1e-3;
}
else if constexpr(std::is_same_v<DataType, double>)
{
return 1e-6;
}
else if constexpr(std::is_same_v<DataType, ck::half_t>)
{
return 1e-3;
}
else if constexpr(std::is_same_v<DataType, ck::bhalf_t>)
{
return 5e-2;
}
else if constexpr(std::is_same_v<DataType, int32_t>)
{
return 1e-1;
}
else if constexpr(std::is_same_v<DataType, int8_t>)
{
return 1e-1;
}
else if constexpr(std::is_same_v<DataType, ck::f8_t>)
{
return 16.1; // 240 and 224 are acceptable
}
else if constexpr(std::is_same_v<DataType, ck::bf8_t>)
{
return 8192.1; // 57344 and 49152 are acceptable
}
else
{
return 1e-3;
}
}
template <ck::index_t NumDimSpatial, ck::index_t NumNonSpatialDim = 3>
std::size_t
GetFlops(const std::array<ck::index_t, NumDimSpatial + NumNonSpatialDim>& output_lengths,
const std::array<ck::index_t, NumDimSpatial + NumNonSpatialDim>& weights_lengths,
const std::size_t& ds_size)
{
// G * N * C * <output spatial lengths product> * (2 * K * <filter spatial lengths product> +
// <number of scale factors>)
ck::index_t G = weights_lengths[0];
ck::index_t N = output_lengths[1];
ck::index_t K = weights_lengths[1];
ck::index_t C = weights_lengths[2];
return G * N * C *
std::accumulate(std::next(std::begin(output_lengths), NumNonSpatialDim),
std::end(output_lengths),
static_cast<std::size_t>(1),
std::multiplies<>()) *
(static_cast<std::size_t>(2) * K *
std::accumulate(std::next(std::begin(weights_lengths), NumNonSpatialDim),
std::end(weights_lengths),
static_cast<std::size_t>(1),
std::multiplies<>()) +
ds_size);
}
template <ck::index_t NDimSpatial,
typename InDataType,
typename WeiDataType,
typename CShuffleDataType,
typename DsDataType,
typename OutDataType,
typename InElementOp,
typename WeiElementOp,
typename OutElementOp,
typename DeviceConvNDFwdInstance>
bool run_grouped_conv_fwd(bool do_verification,
int init_method,
bool time_kernel,
const ck::utils::conv::ConvParam& conv_param,
const HostTensorDescriptor& in_g_n_c_wis_desc,
const HostTensorDescriptor& wei_g_k_c_xs_desc,
const HostTensorDescriptor& out_g_n_k_wos_desc,
const InElementOp& in_element_op,
const WeiElementOp& wei_element_op)
{
Tensor<InDataType> in(in_g_n_c_wis_desc);
Tensor<WeiDataType> wei(wei_g_k_c_xs_desc);
Tensor<CShuffleDataType> c(out_g_n_k_wos_desc);
Tensor<OutDataType> out_host(out_g_n_k_wos_desc);
Tensor<OutDataType> out_device(out_g_n_k_wos_desc);
std::cout << "in: " << in.mDesc << std::endl;
std::cout << "wei: " << wei.mDesc << std::endl;
std::cout << "out: " << out_host.mDesc << std::endl;
switch(init_method)
{
case 0: break;
case 1:
in.GenerateTensorValue(GeneratorTensor_2<InDataType>{-5, 5});
wei.GenerateTensorValue(GeneratorTensor_2<WeiDataType>{-5, 5});
break;
default:
in.GenerateTensorValue(GeneratorTensor_3<InDataType>{0.0, 1.0});
wei.GenerateTensorValue(GeneratorTensor_3<WeiDataType>{-0.5, 0.5});
}
DeviceMem in_device_buf(sizeof(InDataType) * in.mDesc.GetElementSpaceSize());
DeviceMem wei_device_buf(sizeof(WeiDataType) * wei.mDesc.GetElementSpaceSize());
DeviceMem out_device_buf(sizeof(OutDataType) * out_device.mDesc.GetElementSpaceSize());
in_device_buf.ToDevice(in.mData.data());
wei_device_buf.ToDevice(wei.mData.data());
std::array<ck::index_t, NDimSpatial + 3> a_g_n_c_wis_lengths{};
std::array<ck::index_t, NDimSpatial + 3> a_g_n_c_wis_strides{};
std::array<ck::index_t, NDimSpatial + 3> b_g_k_c_xs_lengths{};
std::array<ck::index_t, NDimSpatial + 3> b_g_k_c_xs_strides{};
std::array<ck::index_t, NDimSpatial + 3> e_g_n_k_wos_lengths{};
std::array<ck::index_t, NDimSpatial + 3> e_g_n_k_wos_strides{};
std::array<ck::index_t, NDimSpatial> conv_filter_strides{};
std::array<ck::index_t, NDimSpatial> conv_filter_dilations{};
std::array<ck::index_t, NDimSpatial> input_left_pads{};
std::array<ck::index_t, NDimSpatial> input_right_pads{};
auto copy = [](const auto& x, auto& y) { ck::ranges::copy(x, y.begin()); };
copy(in_g_n_c_wis_desc.GetLengths(), a_g_n_c_wis_lengths);
copy(in_g_n_c_wis_desc.GetStrides(), a_g_n_c_wis_strides);
copy(wei_g_k_c_xs_desc.GetLengths(), b_g_k_c_xs_lengths);
copy(wei_g_k_c_xs_desc.GetStrides(), b_g_k_c_xs_strides);
copy(out_g_n_k_wos_desc.GetLengths(), e_g_n_k_wos_lengths);
copy(out_g_n_k_wos_desc.GetStrides(), e_g_n_k_wos_strides);
copy(conv_param.conv_filter_strides_, conv_filter_strides);
copy(conv_param.conv_filter_dilations_, conv_filter_dilations);
copy(conv_param.input_left_pads_, input_left_pads);
copy(conv_param.input_right_pads_, input_right_pads);
// random scale values
float scale_in = float(std::rand()) / float(RAND_MAX);
float scale_wei = float(std::rand()) / float(RAND_MAX);
float scale_out = float(std::rand()) / float(RAND_MAX);
// initialize out_element_op for each iteration
const auto out_element_op = OutElementOp{scale_in, scale_wei, scale_out};
// do Conv
auto conv = DeviceConvNDFwdInstance{};
auto invoker = conv.MakeInvoker();
auto argument = conv.MakeArgument(in_device_buf.GetDeviceBuffer(),
wei_device_buf.GetDeviceBuffer(),
std::array<const void*, 0>{},
out_device_buf.GetDeviceBuffer(),
a_g_n_c_wis_lengths,
a_g_n_c_wis_strides,
b_g_k_c_xs_lengths,
b_g_k_c_xs_strides,
std::array<std::array<ck::index_t, NDimSpatial + 3>, 0>{},
std::array<std::array<ck::index_t, NDimSpatial + 3>, 0>{},
e_g_n_k_wos_lengths,
e_g_n_k_wos_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads,
in_element_op,
wei_element_op,
out_element_op);
if(!conv.IsSupportedArgument(argument))
{
throw std::runtime_error(
"wrong! device_conv with the specified compilation parameters does "
"not support this Conv problem");
}
float avg_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
std::size_t ds_size = 3; // 3 element-wise scale multipliers
std::size_t flop = GetFlops<NDimSpatial>(e_g_n_k_wos_lengths, b_g_k_c_xs_lengths, ds_size);
std::size_t num_btype = conv_param.GetInputByte<InDataType>() +
conv_param.GetWeightByte<WeiDataType>() + sizeof(float) +
sizeof(float) + sizeof(float) + conv_param.GetOutputByte<OutDataType>();
float tflops = static_cast<float>(flop) / 1.E9 / avg_time;
float gb_per_sec = num_btype / 1.E6 / avg_time;
std::cout << "Perf: " << avg_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
<< conv.GetTypeString() << std::endl;
if(do_verification)
{
auto ref_conv = ck::tensor_operation::host::ReferenceConvFwd<NDimSpatial,
InDataType,
WeiDataType,
CShuffleDataType,
InElementOp,
WeiElementOp,
PassThrough>();
auto ref_invoker = ref_conv.MakeInvoker();
auto ref_argument = ref_conv.MakeArgument(in,
wei,
c,
conv_param.conv_filter_strides_,
conv_param.conv_filter_dilations_,
conv_param.input_left_pads_,
conv_param.input_right_pads_,
in_element_op,
wei_element_op,
PassThrough{});
ref_invoker.Run(ref_argument);
out_host.ForEach([&](auto&, auto idx) { out_element_op(out_host(idx), c(idx)); });
out_device_buf.FromDevice(out_device.mData.data());
return ck::utils::check_err(out_device,
out_host,
"Error: incorrect results!",
get_rtol<OutDataType>(),
get_atol<OutDataType>());
}
return true;
}
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "convnd_fwd_convscale_common.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp"
#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp"
using InDataType = ck::f8_t;
using WeiDataType = ck::f8_t;
using AccDataType = float;
using CShuffleDataType = float;
using DsDataType = ck::Tuple<>;
using OutDataType = ck::f8_t;
using AComputeDataType = ck::f8_t;
using BComputeDataType = ck::f8_t;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using InElementOp = PassThrough;
using WeiElementOp = PassThrough;
using OutElementOp = ConvScale;
static constexpr auto ConvSpec =
ck::tensor_operation::device::ConvolutionForwardSpecialization::Default;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
template <ck::index_t NDimSpatial,
typename InLayout,
typename WeiLayout,
typename DsLayout,
typename OutLayout>
using DeviceGroupedConvNDFwdInstance =
ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<
NDimSpatial,
InLayout,
WeiLayout,
DsLayout,
OutLayout,
InDataType,
WeiDataType,
AccDataType,
CShuffleDataType,
DsDataType,
OutDataType,
InElementOp,
WeiElementOp,
OutElementOp,
ConvSpec, // ConvForwardSpecialization
GemmSpec, // GemmSpecialization
1, //
256, // BlockSize
128, // MPerBlock
256, // NPerBlock
32, // KPerBlock
8, // AK1
8, // BK1
32, // MPerXdl
32, // NPerXdl
2, // MXdlPerWave
4, // NXdlPerWave
S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // ABlockTransferSrcAccessOrder
2, // ABlockTransferSrcVectorDim
8, // ABlockTransferSrcScalarPerVector
8, // ABlockTransferDstScalarPerVector_AK1
1, // ABlockLdsExtraM
S<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1
S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // BBlockTransferSrcAccessOrder
2, // BBlockTransferSrcVectorDim
8, // BBlockTransferSrcScalarPerVector
8, // BBlockTransferDstScalarPerVector_BK1
1, // BBlockLdsExtraN
1,
1,
S<1, 32, 1, 8>,
8,
AComputeDataType,
BComputeDataType>;
#include "run_convnd_fwd_convscale_example.inc"
int main(int argc, char* argv[]) { return run_convnd_fwd_example(argc, argv) ? 0 : 1; }
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
bool run_convnd_fwd_example(int argc, char* argv[])
{
print_helper_msg();
bool do_verification = true;
int init_method = 1;
bool time_kernel = false;
ck::utils::conv::ConvParam conv_param{
2, 1, 128, 256, 192, {3, 3}, {71, 71}, {2, 2}, {1, 1}, {1, 1}, {1, 1}};
if(argc == 1)
{
// use default
}
else if(argc == 4)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
}
else
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
const ck::index_t num_dim_spatial = std::stoi(argv[4]);
conv_param = ck::utils::conv::parse_conv_param(num_dim_spatial, 5, argv);
}
// instantiate in and wei element ops, will
// instantiate out_element_op below for every iteration
const auto in_element_op = InElementOp{};
const auto wei_element_op = WeiElementOp{};
const auto run =
[&](auto ndim_spatial, auto in_layout, auto wei_layout, auto ds_layout, auto out_layout) {
constexpr ck::index_t ndim_spatial_value = ndim_spatial.value;
using InLayout = decltype(in_layout);
using WeiLayout = decltype(wei_layout);
using DsLayout = decltype(ds_layout);
using OutLayout = decltype(out_layout);
const auto in_g_n_c_wis_desc =
ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed<InLayout>(
conv_param);
const auto wei_g_k_c_xs_desc =
ck::utils::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed<WeiLayout>(
conv_param);
const auto out_g_n_k_wos_desc =
ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed<OutLayout>(
conv_param);
return run_grouped_conv_fwd<ndim_spatial_value,
InDataType,
WeiDataType,
CShuffleDataType,
DsDataType,
OutDataType,
InElementOp,
WeiElementOp,
OutElementOp,
DeviceGroupedConvNDFwdInstance<ndim_spatial_value,
InLayout,
WeiLayout,
DsLayout,
OutLayout>>(
do_verification,
init_method,
time_kernel,
conv_param,
in_g_n_c_wis_desc,
wei_g_k_c_xs_desc,
out_g_n_k_wos_desc,
in_element_op,
wei_element_op);
};
namespace ctc = ck::tensor_layout::convolution;
if(conv_param.num_dim_spatial_ == 1)
{
return run(ck::Number<1>{}, ctc::GNWC{}, ctc::GKXC{}, ck::Tuple<>{}, ctc::GNWK{});
}
else if(conv_param.num_dim_spatial_ == 2)
{
return run(ck::Number<2>{}, ctc::GNHWC{}, ctc::GKYXC{}, ck::Tuple<>{}, ctc::GNHWK{});
}
else if(conv_param.num_dim_spatial_ == 3)
{
return run(ck::Number<3>{}, ctc::GNDHWC{}, ctc::GKZYXC{}, ck::Tuple<>{}, ctc::GNDHWK{});
}
return true;
}
add_example_executable(example_gemm_multiply_multiply_xdl_fp16 gemm_multiply_multiply_xdl_fp16.cpp)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/literals.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/utility/blkgemmpipe_scheduler.hpp"
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using F16 = ck::half_t;
using FP8 = ck::f8_t;
using F32 = float;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using A0DataType = FP8;
using B0DataType = FP8;
using AccDataType = F32;
using CShuffleDataType = F32;
using D0DataType = F32;
using D1DataType = F32;
using DsDataType = ck::Tuple<D0DataType, D1DataType>;
using EDataType = F16;
using A0Layout = Row;
using B0Layout = Col;
using D0Layout = Row;
using D1Layout = Col;
using DsLayout = ck::Tuple<D0Layout, D1Layout>;
using ELayout = Row;
struct MultiplyMultiply
{
template <typename E, typename C, typename D0, typename D1>
__host__ __device__ constexpr void
operator()(E& e, const C& c, const D0& d0, const D1& d1) const;
template <>
__host__ __device__ constexpr void operator()<ck::half_t, float, float, float>(
ck::half_t& e, const float& c, const float& d0, const float& d1) const
{
const float x0_f = c * d0 * d1;
e = ck::type_convert<ck::half_t>(x0_f);
}
};
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CDEElementOp = MultiplyMultiply;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default;
using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShuffle_V3
// clang-format off
///######| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
///######| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
///######| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
///######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | S<C, D0, D1>|
///###### RRR
///< Row, Row, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 256, 128, 64, 16, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1, FP8>;
///###### RCR
< Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 256, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1, FP8>;
// clang-format on
int main(int argc, char* argv[])
{
bool do_verification = true;
int init_method = 1;
bool time_kernel = false;
// GEMM shape
ck::index_t M = 3840;
ck::index_t N = 4096;
ck::index_t K = 4096;
ck::index_t StrideA = K;
ck::index_t StrideB = K;
ck::index_t StrideD = 0;
ck::index_t StrideE = N;
if(argc == 1)
{
// use default case
}
else if(argc == 4)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
}
else if(argc == 11)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
M = std::stoi(argv[4]);
N = std::stoi(argv[5]);
K = std::stoi(argv[6]);
StrideA = std::stoi(argv[7]);
StrideB = std::stoi(argv[8]);
StrideD = std::stoi(argv[9]);
StrideE = std::stoi(argv[10]);
}
else
{
printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: time kernel (0=no, 1=yes)\n");
printf("arg4 to 9: M (256x), N(128x), K(32x), StrideA, StrideB, StrideD, StrideE\n");
exit(0);
}
auto f_host_tensor_descriptor =
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
using namespace ck::literals;
if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
{
return HostTensorDescriptor({row, col}, {stride, 1_uz});
}
else
{
return HostTensorDescriptor({row, col}, {1_uz, stride});
}
};
Tensor<A0DataType> a0_m_k(f_host_tensor_descriptor(M, K, StrideA, A0Layout{}));
Tensor<B0DataType> b0_k_n(f_host_tensor_descriptor(K, N, StrideB, B0Layout{}));
Tensor<D0DataType> d0_m_n(f_host_tensor_descriptor(M, N, StrideD, D0Layout{}));
Tensor<D1DataType> d1_m_n(f_host_tensor_descriptor(M, N, StrideD, D1Layout{}));
Tensor<EDataType> e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{}));
Tensor<EDataType> e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{}));
std::cout << "a0_m_k: " << a0_m_k.mDesc << std::endl;
std::cout << "b0_k_n: " << b0_k_n.mDesc << std::endl;
std::cout << "d1_m_n: " << d1_m_n.mDesc << std::endl;
std::cout << "d0_m_n: " << d0_m_n.mDesc << std::endl;
std::cout << "e_m_n: " << e_m_n_host_result.mDesc << std::endl;
switch(init_method)
{
case 0: break;
case 1:
a0_m_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-2, 2});
b0_k_n.GenerateTensorValue(GeneratorTensor_2<B0DataType>{0, 2});
d0_m_n.GenerateTensorValue(GeneratorTensor_2<D0DataType>{0, 2});
d1_m_n.GenerateTensorValue(GeneratorTensor_2<D1DataType>{0, 2});
break;
default:
a0_m_k.GenerateTensorValue(GeneratorTensor_3<A0DataType>{0.0, 1.0});
b0_k_n.GenerateTensorValue(GeneratorTensor_3<B0DataType>{-0.5, 0.5});
d0_m_n.GenerateTensorValue(GeneratorTensor_3<D0DataType>{-0.5, 0.5});
d1_m_n.GenerateTensorValue(GeneratorTensor_3<D1DataType>{-0.5, 0.5});
}
DeviceMem a0_device_buf(sizeof(A0DataType) * a0_m_k.mDesc.GetElementSpaceSize());
DeviceMem b0_device_buf(sizeof(B0DataType) * b0_k_n.mDesc.GetElementSpaceSize());
DeviceMem d0_device_buf(sizeof(D0DataType) * d0_m_n.mDesc.GetElementSpaceSize());
DeviceMem d1_device_buf(sizeof(D1DataType) * d1_m_n.mDesc.GetElementSpaceSize());
DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpaceSize());
a0_device_buf.ToDevice(a0_m_k.mData.data());
b0_device_buf.ToDevice(b0_k_n.mData.data());
d0_device_buf.ToDevice(d0_m_n.mData.data());
d1_device_buf.ToDevice(d1_m_n.mData.data());
e_device_buf.ToDevice(e_m_n_device_result.mData.data());
auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{};
auto cde_element_op = CDEElementOp{};
constexpr ck::index_t NumDTensor = DsDataType::Size();
constexpr auto I0 = ck::Number<0>{};
// do GEMM
auto device_op = DeviceOpInstance{};
auto invoker = device_op.MakeInvoker();
auto argument =
device_op.MakeArgument(a0_device_buf.GetDeviceBuffer(),
b0_device_buf.GetDeviceBuffer(),
std::array<const void*, NumDTensor>{d0_device_buf.GetDeviceBuffer(),
d1_device_buf.GetDeviceBuffer()},
e_device_buf.GetDeviceBuffer(),
M,
N,
K,
StrideA,
StrideB,
std::array<ck::index_t, NumDTensor>{I0, I0},
StrideE,
a_element_op,
b_element_op,
cde_element_op);
if(!device_op.IsSupportedArgument(argument))
{
throw std::runtime_error(
"wrong! device_gemm with the specified compilation parameters does "
"not support this GEMM problem");
}
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel, 20, 50});
std::size_t flop = std::size_t(2) * M * N * K;
std::size_t num_btype =
sizeof(A0DataType) * M * K + sizeof(B0DataType) * K * N + sizeof(EDataType) * M * N;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s"
<< std::endl;
e_device_buf.FromDevice(e_m_n_device_result.mData.data());
if(do_verification)
{
Tensor<CShuffleDataType> c_m_n({M, N});
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<A0DataType,
B0DataType,
CShuffleDataType,
AccDataType,
PassThrough,
PassThrough,
PassThrough>;
auto ref_gemm = ReferenceGemmInstance{};
auto ref_invoker = ref_gemm.MakeInvoker();
auto ref_argument = ref_gemm.MakeArgument(
a0_m_k, b0_k_n, c_m_n, PassThrough{}, PassThrough{}, PassThrough{});
ref_invoker.Run(ref_argument);
for(int m = 0; m < M; ++m)
{
for(int n = 0; n < N; ++n)
{
cde_element_op(e_m_n_host_result(m, n), c_m_n(m, n), d0_m_n(m, n), d1_m_n(m, n));
}
}
e_device_buf.FromDevice(e_m_n_device_result.mData.data());
return ck::utils::check_err(e_m_n_device_result, e_m_n_host_result) ? 0 : 1;
}
return 0;
}
# generate a list of kernels, but not actually emit files at config stage # generate a list of kernels, but not actually emit files at config stage
execute_process( execute_process(
COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py
--list_blobs ${CMAKE_CURRENT_BINARY_DIR}/blob_list.txt --direction fwd --list_blobs ${CMAKE_CURRENT_BINARY_DIR}/fwd_blob_list.txt
) )
# NOTE: for cmake, the FMHA_FWD_GEN_BLOBS files must be in the same directory execute_process(
COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py
--direction bwd --list_blobs ${CMAKE_CURRENT_BINARY_DIR}/bwd_blob_list.txt
)
# NOTE: for cmake, the FMHA_FWD_GEN_BLOBS/FMHA_BWD_GEN_BLOBS files must be in the same directory
# as current cmake list, otherwise will not figure out the dependency properly # as current cmake list, otherwise will not figure out the dependency properly
file(STRINGS ${CMAKE_CURRENT_BINARY_DIR}/blob_list.txt FMHA_FWD_GEN_BLOBS) file(STRINGS ${CMAKE_CURRENT_BINARY_DIR}/fwd_blob_list.txt FMHA_FWD_GEN_BLOBS)
file(STRINGS ${CMAKE_CURRENT_BINARY_DIR}/bwd_blob_list.txt FMHA_BWD_GEN_BLOBS)
add_custom_command( add_custom_command(
OUTPUT ${FMHA_FWD_GEN_BLOBS} OUTPUT ${FMHA_FWD_GEN_BLOBS}
COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py
--output_dir ${CMAKE_CURRENT_BINARY_DIR} --direction fwd --output_dir ${CMAKE_CURRENT_BINARY_DIR}
)
add_custom_command(
OUTPUT ${FMHA_BWD_GEN_BLOBS}
COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py
--direction bwd --output_dir ${CMAKE_CURRENT_BINARY_DIR}
) )
set(EXAMPLE_FMHA_FWD "tile_example_fmha_fwd") set(EXAMPLE_FMHA_FWD "tile_example_fmha_fwd")
...@@ -22,6 +34,14 @@ add_executable(${EXAMPLE_FMHA_FWD} EXCLUDE_FROM_ALL fmha_fwd.cpp) ...@@ -22,6 +34,14 @@ add_executable(${EXAMPLE_FMHA_FWD} EXCLUDE_FROM_ALL fmha_fwd.cpp)
target_include_directories(${EXAMPLE_FMHA_FWD} PRIVATE ${CMAKE_CURRENT_LIST_DIR}) target_include_directories(${EXAMPLE_FMHA_FWD} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
target_sources(${EXAMPLE_FMHA_FWD} PRIVATE ${FMHA_FWD_GEN_BLOBS}) target_sources(${EXAMPLE_FMHA_FWD} PRIVATE ${FMHA_FWD_GEN_BLOBS})
set(EXAMPLE_FMHA_BWD "tile_example_fmha_bwd")
# not using add_example_executable() to add this target, since we don't want this to have
# to be included in "make all/install/check"
message("adding example ${EXAMPLE_FMHA_BWD}")
add_executable(${EXAMPLE_FMHA_BWD} EXCLUDE_FROM_ALL fmha_bwd.cpp)
target_include_directories(${EXAMPLE_FMHA_BWD} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
target_sources(${EXAMPLE_FMHA_BWD} PRIVATE ${FMHA_BWD_GEN_BLOBS})
# NOTE: this is dangerous since will change the whole kernel to flush denormals # NOTE: this is dangerous since will change the whole kernel to flush denormals
# WIP with compiler team for an exp2 intrinsic..., then remove this # WIP with compiler team for an exp2 intrinsic..., then remove this
if(NOT DEFINED FMHA_FWD_FAST_EXP2) if(NOT DEFINED FMHA_FWD_FAST_EXP2)
...@@ -29,16 +49,27 @@ if(NOT DEFINED FMHA_FWD_FAST_EXP2) ...@@ -29,16 +49,27 @@ if(NOT DEFINED FMHA_FWD_FAST_EXP2)
endif() endif()
set(EXAMPLE_FMHA_FWD_COMPILE_OPTIONS) set(EXAMPLE_FMHA_FWD_COMPILE_OPTIONS)
set(EXAMPLE_FMHA_BWD_COMPILE_OPTIONS)
# NOTE: we turn off undefined-func-template to let source compile without explicit declare function specializations # NOTE: we turn off undefined-func-template to let source compile without explicit declare function specializations
# ... because they are auto-generated # ... because they are auto-generated
if(FMHA_FWD_FAST_EXP2) if(FMHA_FWD_FAST_EXP2)
list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -Wno-undefined-func-template -DCK_TILE_FMHA_FWD_FAST_EXP2=1 -fgpu-flush-denormals-to-zero) list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -Wno-undefined-func-template -DCK_TILE_FMHA_FWD_FAST_EXP2=1 -fgpu-flush-denormals-to-zero)
list(APPEND EXAMPLE_FMHA_BWD_COMPILE_OPTIONS -Wno-undefined-func-template -DCK_TILE_FMHA_FWD_FAST_EXP2=1 -fgpu-flush-denormals-to-zero)
else() else()
list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -Wno-undefined-func-template -DCK_TILE_FMHA_FWD_FAST_EXP2=0) list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -Wno-undefined-func-template -DCK_TILE_FMHA_FWD_FAST_EXP2=0)
list(APPEND EXAMPLE_FMHA_BWD_COMPILE_OPTIONS -Wno-undefined-func-template -DCK_TILE_FMHA_FWD_FAST_EXP2=0)
endif() endif()
# Allow comparing floating points directly in order to check sentinel values # Allow comparing floating points directly in order to check sentinel values
list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -Wno-float-equal) list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -Wno-float-equal)
list(APPEND EXAMPLE_FMHA_BWD_COMPILE_OPTIONS -Wno-float-equal)
target_compile_options(${EXAMPLE_FMHA_FWD} PRIVATE ${EXAMPLE_FMHA_FWD_COMPILE_OPTIONS}) target_compile_options(${EXAMPLE_FMHA_FWD} PRIVATE ${EXAMPLE_FMHA_FWD_COMPILE_OPTIONS})
target_compile_options(${EXAMPLE_FMHA_BWD} PRIVATE ${EXAMPLE_FMHA_BWD_COMPILE_OPTIONS})
# TODO: we have to turn off this global prop, otherwise the progress bar generated
# by cmake will print too many files, execvp: /bin/sh: Argument list too long
# however, this property may affect global
# TODO: consider codegen a makefile by us
set_property(GLOBAL PROPERTY RULE_MESSAGES OFF)
...@@ -34,6 +34,7 @@ args: ...@@ -34,6 +34,7 @@ args:
if not equal to h, then this is GQA/MQA case if not equal to h, then this is GQA/MQA case
-s seqlen_q. if group-mode, means the average value of seqlen_q (default:3328) -s seqlen_q. if group-mode, means the average value of seqlen_q (default:3328)
total_seqlen_q = seqlen_q * batch, and seqlen_q per batch may vary total_seqlen_q = seqlen_q * batch, and seqlen_q per batch may vary
also with "-s=s0,s1,s2..." comma seperated int to set per batch seqlen(group-mode)
-s_k seqlen_k, -1 means equal to s (default:-1) -s_k seqlen_k, -1 means equal to s (default:-1)
-d head dim for q, k (default:128) -d head dim for q, k (default:128)
-d_v head dim for v, -1 means equal to d (default:-1) -d_v head dim for v, -1 means equal to d (default:-1)
......
This diff is collapsed.
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/ops/fmha.hpp"
#include "ck_tile/ops/epilogue.hpp"
#include "mask.hpp"
#include "bias.hpp"
#include <type_traits>
template <typename DataType>
struct FmhaBwdTypeConfig;
template <>
struct FmhaBwdTypeConfig<ck_tile::half_t>
{
using QDataType = ck_tile::half_t;
using KDataType = ck_tile::half_t;
using VDataType = ck_tile::half_t;
using GemmDataType = ck_tile::half_t;
using BiasDataType = ck_tile::half_t;
using LSEDataType = float;
using AccDataType = float; // data type for gemm accumulation
using DDataType = float;
using RandValOutputDataType = uint8_t;
using ODataType = ck_tile::half_t;
using OGradDataType = ck_tile::half_t;
using QGradDataType = ck_tile::half_t;
using KGradDataType = ck_tile::half_t;
using VGradDataType = ck_tile::half_t;
using BiasGradDataType = ck_tile::half_t;
};
template <>
struct FmhaBwdTypeConfig<ck_tile::bf16_t>
{
using QDataType = ck_tile::bf16_t;
using KDataType = ck_tile::bf16_t;
using VDataType = ck_tile::bf16_t;
using GemmDataType = ck_tile::bf16_t;
using BiasDataType = ck_tile::bf16_t;
using LSEDataType = float;
using AccDataType = float; // data type for gemm accumulation
using DDataType = float;
using RandValOutputDataType = uint8_t;
using ODataType = ck_tile::bf16_t;
using OGradDataType = ck_tile::bf16_t;
using QGradDataType = ck_tile::bf16_t;
using KGradDataType = ck_tile::bf16_t;
using VGradDataType = ck_tile::bf16_t;
using BiasGradDataType = ck_tile::bf16_t;
};
struct FmhaMasks
{
using NoMask = ck_tile::GenericAttentionMask<false>;
using GenericMask = ck_tile::GenericAttentionMask<true, true>;
using CausalMask = ck_tile::GenericAttentionMask<true, false>;
};
// runtime args, some will passed to karg, some will used to compute grids/blocks
struct fmha_bwd_args
{
const void* q_ptr;
const void* k_ptr;
const void* v_ptr;
const void* bias_ptr; // bias or alibi_slope pointer
const void* o_ptr;
const void* lse_ptr;
const void* do_ptr;
void* d_ptr;
void* rand_val_ptr;
void* dq_ptr;
void* dk_ptr;
void* dv_ptr;
void* dbias_ptr;
const void* seqstart_q_ptr;
const void* seqstart_k_ptr;
const void* seqlen_k_ptr;
ck_tile::index_t seqlen_q;
ck_tile::index_t seqlen_k;
ck_tile::index_t batch;
ck_tile::index_t max_seqlen_q;
ck_tile::index_t max_seqlen_k;
ck_tile::index_t hdim_q;
ck_tile::index_t hdim_v;
ck_tile::index_t nhead_q;
ck_tile::index_t nhead_k;
float scale;
ck_tile::index_t stride_q;
ck_tile::index_t stride_k;
ck_tile::index_t stride_v;
ck_tile::index_t stride_bias; // if alibi, b*h need set this to h, 1*h need set this to 0
ck_tile::index_t stride_o;
ck_tile::index_t stride_randval;
ck_tile::index_t stride_do;
ck_tile::index_t stride_dk;
ck_tile::index_t stride_dv;
ck_tile::index_t stride_dbias;
ck_tile::index_t nhead_stride_q;
ck_tile::index_t nhead_stride_k;
ck_tile::index_t nhead_stride_v;
ck_tile::index_t nhead_stride_bias;
ck_tile::index_t nhead_stride_o;
ck_tile::index_t nhead_stride_randval;
ck_tile::index_t nhead_stride_do;
ck_tile::index_t nhead_stride_lsed;
ck_tile::index_t nhead_stride_dbias;
ck_tile::index_t batch_stride_q;
ck_tile::index_t batch_stride_k;
ck_tile::index_t batch_stride_v;
ck_tile::index_t batch_stride_bias;
ck_tile::index_t batch_stride_o;
ck_tile::index_t batch_stride_randval;
ck_tile::index_t batch_stride_do;
ck_tile::index_t batch_stride_lsed;
ck_tile::index_t batch_stride_dk;
ck_tile::index_t batch_stride_dv;
ck_tile::index_t batch_stride_dbias;
ck_tile::index_t window_size_left;
ck_tile::index_t window_size_right;
ck_tile::index_t mask_type;
float p_drop;
float p_undrop;
bool s_randval;
std::tuple<uint64_t, uint64_t> drop_seed_offset;
};
template <typename FmhaBwdDQDKDVKernel>
auto fmha_bwd_dq_dk_dv_create_kargs_and_grids(fmha_bwd_args args)
{
assert(args.nhead_q % args.nhead_k == 0);
auto kargs = [&] {
// create group mode kernel arguments
if constexpr(FmhaBwdDQDKDVKernel::kIsGroupMode)
{
return FmhaBwdDQDKDVKernel::MakeKargs(args.q_ptr,
args.k_ptr,
args.v_ptr,
args.bias_ptr,
args.lse_ptr,
args.do_ptr,
args.d_ptr,
args.rand_val_ptr,
args.dq_ptr,
args.dk_ptr,
args.dv_ptr,
args.dbias_ptr,
args.seqstart_q_ptr,
args.seqstart_k_ptr,
args.seqlen_k_ptr,
args.hdim_q,
args.hdim_v,
args.nhead_q,
args.nhead_q / args.nhead_k,
args.scale,
args.stride_q,
args.stride_k,
args.stride_v,
args.stride_bias,
args.stride_randval,
args.stride_do,
args.stride_dk,
args.stride_dv,
args.stride_dbias,
args.nhead_stride_q,
args.nhead_stride_k,
args.nhead_stride_v,
args.nhead_stride_bias,
args.nhead_stride_randval,
args.nhead_stride_do,
args.nhead_stride_lsed,
args.nhead_stride_dbias,
args.batch_stride_lsed,
args.window_size_left,
args.window_size_right,
args.mask_type,
args.p_drop,
args.s_randval,
args.drop_seed_offset);
}
else
{ // create batch mode kernel arguments
return FmhaBwdDQDKDVKernel::MakeKargs(args.q_ptr,
args.k_ptr,
args.v_ptr,
args.bias_ptr,
args.lse_ptr,
args.do_ptr,
args.d_ptr,
args.rand_val_ptr,
args.dq_ptr,
args.dk_ptr,
args.dv_ptr,
args.dbias_ptr,
args.seqlen_q,
args.seqlen_k,
args.hdim_q,
args.hdim_v,
args.nhead_q,
args.nhead_q / args.nhead_k,
args.scale,
args.stride_q,
args.stride_k,
args.stride_v,
args.stride_bias,
args.stride_randval,
args.stride_do,
args.stride_dk,
args.stride_dv,
args.stride_dbias,
args.nhead_stride_q,
args.nhead_stride_k,
args.nhead_stride_v,
args.nhead_stride_bias,
args.nhead_stride_randval,
args.nhead_stride_do,
args.nhead_stride_lsed,
args.nhead_stride_dbias,
args.batch_stride_q,
args.batch_stride_k,
args.batch_stride_v,
args.batch_stride_bias,
args.batch_stride_randval,
args.batch_stride_do,
args.batch_stride_lsed,
args.batch_stride_dk,
args.batch_stride_dv,
args.batch_stride_dbias,
args.window_size_left,
args.window_size_right,
args.mask_type,
args.p_drop,
args.s_randval,
args.drop_seed_offset);
}
}();
dim3 grids = FmhaBwdDQDKDVKernel::GridSize(args.batch, args.nhead_q, args.max_seqlen_k);
return ck_tile::make_tuple(kargs, grids);
}
template <typename FmhaBwdOGradDotOKernel>
auto fmha_bwd_dot_do_o_create_kargs_and_grids(fmha_bwd_args args)
{
auto kargs = [&] {
// create group mode kernel arguments
if constexpr(FmhaBwdOGradDotOKernel::kIsGroupMode)
{
return FmhaBwdOGradDotOKernel::MakeKargs(args.o_ptr,
args.do_ptr,
args.d_ptr,
args.p_undrop,
args.seqstart_q_ptr,
args.hdim_v,
args.stride_do,
args.stride_o,
args.nhead_stride_do,
args.nhead_stride_o,
args.nhead_stride_lsed,
args.batch_stride_lsed);
}
else
{ // create batch mode kernel arguments
return FmhaBwdOGradDotOKernel::MakeKargs(args.o_ptr,
args.do_ptr,
args.d_ptr,
args.p_undrop,
args.seqlen_q,
args.hdim_v,
args.stride_do,
args.stride_o,
args.nhead_stride_do,
args.nhead_stride_o,
args.nhead_stride_lsed,
args.batch_stride_do,
args.batch_stride_o,
args.batch_stride_lsed);
}
}();
dim3 grids = FmhaBwdOGradDotOKernel::GridSize(args.batch, args.nhead_q, args.max_seqlen_q);
return ck_tile::make_tuple(kargs, grids);
}
// this is used to pattern-match internl kernel implementation, not to instantiate kernel
template <ck_tile::index_t HDim_,
typename DataType_,
bool kIsGroupMode_,
ck_tile::BlockFmhaBwdPipelineEnum FmhaBwdPipelineEnum_,
typename FmhaMask_,
ck_tile::BlockAttentionBiasEnum BiasEnum_,
bool kHasBiasGrad_,
bool kHasDropout_,
bool kPadS_,
bool kPadSK_,
bool kPadD_,
bool kPadDv_>
struct fmha_bwd_dq_dk_dv_traits_
{
static constexpr ck_tile::index_t HDim = HDim_;
using DataType = ck_tile::remove_cvref_t<DataType_>;
static constexpr bool kIsGroupMode = kIsGroupMode_;
static constexpr auto FmhaBwdPipelineEnum = FmhaBwdPipelineEnum_;
using FmhaMask = ck_tile::remove_cvref_t<FmhaMask_>;
static constexpr auto BiasEnum = BiasEnum_;
static constexpr bool kHasBiasGrad = kHasBiasGrad_;
static constexpr bool kHasDropout = kHasDropout_;
static constexpr bool kPadS = kPadS_;
static constexpr bool kPadSK = kPadSK_;
static constexpr bool kPadD = kPadD_;
static constexpr bool kPadDv = kPadDv_;
};
template <typename Traits_>
float fmha_bwd_dq_dk_dv_(const ck_tile::stream_config&, fmha_bwd_args);
template <typename Traits_>
void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config&, fmha_bwd_args);
template <typename Traits_>
std::string fmha_bwd_dq_dk_dv_get_name_();
template <ck_tile::index_t HDim_, typename DataType_, bool kIsGroupMode_, bool kPadS_, bool kPadDv_>
struct fmha_bwd_dot_do_o_traits_
{
static constexpr ck_tile::index_t HDim = HDim_;
using DataType = ck_tile::remove_cvref_t<DataType_>;
static constexpr bool kIsGroupMode = kIsGroupMode_;
static constexpr bool kPadS = kPadS_;
static constexpr bool kPadDv = kPadDv_;
};
template <typename Traits_>
float fmha_bwd_dot_do_o_(const ck_tile::stream_config&, fmha_bwd_args);
template <typename Traits_>
void fmha_bwd_dot_do_o_oneshot_(const ck_tile::stream_config&, fmha_bwd_args);
template <typename Traits_>
std::string fmha_bwd_dot_do_o_get_name_();
// This is the public API, will be generated by script
struct fmha_bwd_traits
{
int hdim_q;
int hdim_v;
std::string data_type;
bool is_group_mode;
mask_enum mask_type;
bias_enum bias_type; // 0:no bias, 1:elementwise bias, 2:alibi. sync with BlockAttentionBiasEnum
bool has_dbias;
bool has_dropout;
// TODO: padding check is inside this api
};
float fmha_bwd(fmha_bwd_traits, fmha_bwd_args, const ck_tile::stream_config&);
This diff is collapsed.
...@@ -17,61 +17,65 @@ struct FmhaFwdTypeConfig; ...@@ -17,61 +17,65 @@ struct FmhaFwdTypeConfig;
template <> template <>
struct FmhaFwdTypeConfig<ck_tile::half_t> struct FmhaFwdTypeConfig<ck_tile::half_t>
{ {
using QDataType = ck_tile::half_t; using QDataType = ck_tile::half_t;
using KDataType = ck_tile::half_t; using KDataType = ck_tile::half_t;
using VDataType = ck_tile::half_t; using VDataType = ck_tile::half_t;
using BiasDataType = ck_tile::half_t; using BiasDataType = ck_tile::half_t;
using LSEDataType = float; // data type for lse(logsumexp L_j = max_j + log(l_j)) using RandValOutputDataType = uint8_t;
using SaccDataType = float; // data type for first gemm accumulation using LSEDataType = float; // data type for lse(logsumexp L_j = max_j + log(l_j))
using SMPLComputeDataType = float; // data type for reduction, softmax using SaccDataType = float; // data type for first gemm accumulation
using PDataType = ck_tile::half_t; // data type for A matrix of second gemm using SMPLComputeDataType = float; // data type for reduction, softmax
using OaccDataType = float; // data type for second gemm accumulation using PDataType = ck_tile::half_t; // data type for A matrix of second gemm
using ODataType = ck_tile::half_t; using OaccDataType = float; // data type for second gemm accumulation
using ODataType = ck_tile::half_t;
}; };
template <> template <>
struct FmhaFwdTypeConfig<ck_tile::bf16_t> struct FmhaFwdTypeConfig<ck_tile::bf16_t>
{ {
using QDataType = ck_tile::bf16_t; using QDataType = ck_tile::bf16_t;
using KDataType = ck_tile::bf16_t; using KDataType = ck_tile::bf16_t;
using VDataType = ck_tile::bf16_t; using VDataType = ck_tile::bf16_t;
using BiasDataType = ck_tile::bf16_t; using BiasDataType = ck_tile::bf16_t;
using LSEDataType = float; // data type for lse(logsumexp L_j = max_j + log(l_j)) using RandValOutputDataType = uint8_t;
using SaccDataType = float; // data type for first gemm accumulation using LSEDataType = float; // data type for lse(logsumexp L_j = max_j + log(l_j))
using SMPLComputeDataType = float; // data type for reduction, softmax using SaccDataType = float; // data type for first gemm accumulation
using PDataType = ck_tile::bf16_t; // data type for A matrix of second gemm using SMPLComputeDataType = float; // data type for reduction, softmax
using OaccDataType = float; // data type for second gemm accumulation using PDataType = ck_tile::bf16_t; // data type for A matrix of second gemm
using ODataType = ck_tile::bf16_t; using OaccDataType = float; // data type for second gemm accumulation
using ODataType = ck_tile::bf16_t;
}; };
template <> template <>
struct FmhaFwdTypeConfig<ck_tile::fp8_t> struct FmhaFwdTypeConfig<ck_tile::fp8_t>
{ {
using QDataType = ck_tile::fp8_t; using QDataType = ck_tile::fp8_t;
using KDataType = ck_tile::fp8_t; using KDataType = ck_tile::fp8_t;
using VDataType = ck_tile::fp8_t; using VDataType = ck_tile::fp8_t;
using BiasDataType = float; using BiasDataType = float;
using LSEDataType = float; // data type for lse(logsumexp L_j = max_j + log(l_j)) using RandValOutputDataType = uint8_t;
using SaccDataType = float; // data type for first gemm accumulation using LSEDataType = float; // data type for lse(logsumexp L_j = max_j + log(l_j))
using SMPLComputeDataType = float; // data type for reduction, softmax using SaccDataType = float; // data type for first gemm accumulation
using PDataType = ck_tile::fp8_t; // data type for A matrix of second gemm using SMPLComputeDataType = float; // data type for reduction, softmax
using OaccDataType = float; // data type for second gemm accumulation using PDataType = ck_tile::fp8_t; // data type for A matrix of second gemm
using ODataType = ck_tile::fp8_t; using OaccDataType = float; // data type for second gemm accumulation
using ODataType = ck_tile::fp8_t;
}; };
template <> template <>
struct FmhaFwdTypeConfig<ck_tile::bf8_t> struct FmhaFwdTypeConfig<ck_tile::bf8_t>
{ {
using QDataType = ck_tile::bf8_t; using QDataType = ck_tile::bf8_t;
using KDataType = ck_tile::bf8_t; using KDataType = ck_tile::bf8_t;
using VDataType = ck_tile::bf8_t; using VDataType = ck_tile::bf8_t;
using BiasDataType = ck_tile::bf8_t; using BiasDataType = ck_tile::bf8_t;
using LSEDataType = float; // data type for lse(logsumexp L_j = max_j + log(l_j)) using RandValOutputDataType = uint8_t;
using SaccDataType = float; // data type for first gemm accumulation using LSEDataType = float; // data type for lse(logsumexp L_j = max_j + log(l_j))
using SMPLComputeDataType = float; // data type for reduction, softmax using SaccDataType = float; // data type for first gemm accumulation
using PDataType = ck_tile::bf8_t; // data type for A matrix of second gemm using SMPLComputeDataType = float; // data type for reduction, softmax
using OaccDataType = float; // data type for second gemm accumulation using PDataType = ck_tile::bf8_t; // data type for A matrix of second gemm
using ODataType = ck_tile::bf8_t; using OaccDataType = float; // data type for second gemm accumulation
using ODataType = ck_tile::bf8_t;
}; };
struct FmhaMasks struct FmhaMasks
...@@ -88,6 +92,7 @@ struct fmha_fwd_args ...@@ -88,6 +92,7 @@ struct fmha_fwd_args
const void* k_ptr; const void* k_ptr;
const void* v_ptr; const void* v_ptr;
const void* bias_ptr; // bias or alibi_slope pointer const void* bias_ptr; // bias or alibi_slope pointer
void* rand_val_ptr;
void* lse_ptr; void* lse_ptr;
void* o_ptr; void* o_ptr;
const void* seqstart_q_ptr; const void* seqstart_q_ptr;
...@@ -108,22 +113,28 @@ struct fmha_fwd_args ...@@ -108,22 +113,28 @@ struct fmha_fwd_args
ck_tile::index_t stride_k; ck_tile::index_t stride_k;
ck_tile::index_t stride_v; ck_tile::index_t stride_v;
ck_tile::index_t stride_bias; // if alibi, b*h need set this to h, 1*h need set this to 0 ck_tile::index_t stride_bias; // if alibi, b*h need set this to h, 1*h need set this to 0
ck_tile::index_t stride_randval;
ck_tile::index_t stride_o; ck_tile::index_t stride_o;
ck_tile::index_t nhead_stride_q; ck_tile::index_t nhead_stride_q;
ck_tile::index_t nhead_stride_k; ck_tile::index_t nhead_stride_k;
ck_tile::index_t nhead_stride_v; ck_tile::index_t nhead_stride_v;
ck_tile::index_t nhead_stride_bias; ck_tile::index_t nhead_stride_bias;
ck_tile::index_t nhead_stride_randval;
ck_tile::index_t nhead_stride_lse; ck_tile::index_t nhead_stride_lse;
ck_tile::index_t nhead_stride_o; ck_tile::index_t nhead_stride_o;
ck_tile::index_t batch_stride_q; ck_tile::index_t batch_stride_q;
ck_tile::index_t batch_stride_k; ck_tile::index_t batch_stride_k;
ck_tile::index_t batch_stride_v; ck_tile::index_t batch_stride_v;
ck_tile::index_t batch_stride_bias; ck_tile::index_t batch_stride_bias;
ck_tile::index_t batch_stride_randval;
ck_tile::index_t batch_stride_lse; ck_tile::index_t batch_stride_lse;
ck_tile::index_t batch_stride_o; ck_tile::index_t batch_stride_o;
ck_tile::index_t window_size_left; ck_tile::index_t window_size_left;
ck_tile::index_t window_size_right; ck_tile::index_t window_size_right;
ck_tile::index_t mask_type; ck_tile::index_t mask_type;
float p_drop;
bool s_randval;
std::tuple<uint64_t, uint64_t> drop_seed_offset;
}; };
template <typename FmhaKernel> template <typename FmhaKernel>
...@@ -138,6 +149,7 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args) ...@@ -138,6 +149,7 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
args.k_ptr, args.k_ptr,
args.v_ptr, args.v_ptr,
args.bias_ptr, args.bias_ptr,
args.rand_val_ptr,
args.lse_ptr, args.lse_ptr,
args.o_ptr, args.o_ptr,
args.seqstart_q_ptr, args.seqstart_q_ptr,
...@@ -145,6 +157,7 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args) ...@@ -145,6 +157,7 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
args.seqlen_k_ptr, args.seqlen_k_ptr,
args.hdim_q, args.hdim_q,
args.hdim_v, args.hdim_v,
args.nhead_q,
args.nhead_q / args.nhead_k, args.nhead_q / args.nhead_k,
args.scale_s, args.scale_s,
args.scale_p, args.scale_p,
...@@ -153,16 +166,22 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args) ...@@ -153,16 +166,22 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
args.stride_k, args.stride_k,
args.stride_v, args.stride_v,
args.stride_bias, args.stride_bias,
args.stride_randval,
args.stride_o, args.stride_o,
args.nhead_stride_q, args.nhead_stride_q,
args.nhead_stride_k, args.nhead_stride_k,
args.nhead_stride_v, args.nhead_stride_v,
args.nhead_stride_bias, args.nhead_stride_bias,
args.nhead_stride_randval,
args.nhead_stride_lse, args.nhead_stride_lse,
args.nhead_stride_o, args.nhead_stride_o,
args.batch_stride_lse,
args.window_size_left, args.window_size_left,
args.window_size_right, args.window_size_right,
args.mask_type); args.mask_type,
args.p_drop,
args.s_randval,
args.drop_seed_offset);
} }
else else
{ // create batch mode kernel arguments { // create batch mode kernel arguments
...@@ -170,12 +189,14 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args) ...@@ -170,12 +189,14 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
args.k_ptr, args.k_ptr,
args.v_ptr, args.v_ptr,
args.bias_ptr, args.bias_ptr,
args.rand_val_ptr,
args.lse_ptr, args.lse_ptr,
args.o_ptr, args.o_ptr,
args.seqlen_q, args.seqlen_q,
args.seqlen_k, args.seqlen_k,
args.hdim_q, args.hdim_q,
args.hdim_v, args.hdim_v,
args.nhead_q,
args.nhead_q / args.nhead_k, args.nhead_q / args.nhead_k,
args.scale_s, args.scale_s,
args.scale_p, args.scale_p,
...@@ -184,22 +205,28 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args) ...@@ -184,22 +205,28 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
args.stride_k, args.stride_k,
args.stride_v, args.stride_v,
args.stride_bias, args.stride_bias,
args.stride_randval,
args.stride_o, args.stride_o,
args.nhead_stride_q, args.nhead_stride_q,
args.nhead_stride_k, args.nhead_stride_k,
args.nhead_stride_v, args.nhead_stride_v,
args.nhead_stride_bias, args.nhead_stride_bias,
args.nhead_stride_randval,
args.nhead_stride_lse, args.nhead_stride_lse,
args.nhead_stride_o, args.nhead_stride_o,
args.batch_stride_q, args.batch_stride_q,
args.batch_stride_k, args.batch_stride_k,
args.batch_stride_v, args.batch_stride_v,
args.batch_stride_bias, args.batch_stride_bias,
args.batch_stride_randval,
args.batch_stride_lse, args.batch_stride_lse,
args.batch_stride_o, args.batch_stride_o,
args.window_size_left, args.window_size_left,
args.window_size_right, args.window_size_right,
args.mask_type); args.mask_type,
args.p_drop,
args.s_randval,
args.drop_seed_offset);
} }
}(); }();
...@@ -222,6 +249,7 @@ template <ck_tile::index_t HDim_, ...@@ -222,6 +249,7 @@ template <ck_tile::index_t HDim_,
typename FmhaMask_, typename FmhaMask_,
ck_tile::BlockAttentionBiasEnum BiasEnum_, ck_tile::BlockAttentionBiasEnum BiasEnum_,
bool kStoreLse_, bool kStoreLse_,
bool kHasDropout_,
bool kDoFp8StaticQuant_, bool kDoFp8StaticQuant_,
bool kPadS_, bool kPadS_,
bool kPadSK_, bool kPadSK_,
...@@ -243,6 +271,7 @@ struct fmha_fwd_traits_ ...@@ -243,6 +271,7 @@ struct fmha_fwd_traits_
using FmhaMask = ck_tile::remove_cvref_t<FmhaMask_>; using FmhaMask = ck_tile::remove_cvref_t<FmhaMask_>;
static constexpr auto BiasEnum = BiasEnum_; static constexpr auto BiasEnum = BiasEnum_;
static constexpr bool kStoreLse = kStoreLse_; static constexpr bool kStoreLse = kStoreLse_;
static constexpr bool kHasDropout = kHasDropout_;
static constexpr bool kDoFp8StaticQuant = kDoFp8StaticQuant_; static constexpr bool kDoFp8StaticQuant = kDoFp8StaticQuant_;
static constexpr bool kPadS = kPadS_; static constexpr bool kPadS = kPadS_;
static constexpr bool kPadSK = kPadSK_; static constexpr bool kPadSK = kPadSK_;
...@@ -264,6 +293,7 @@ struct fmha_fwd_traits ...@@ -264,6 +293,7 @@ struct fmha_fwd_traits
mask_enum mask_type; mask_enum mask_type;
bias_enum bias_type; // 0:no bias, 1:elementwise bias, 2:alibi. sync with BlockAttentionBiasEnum bias_enum bias_type; // 0:no bias, 1:elementwise bias, 2:alibi. sync with BlockAttentionBiasEnum
bool has_lse; bool has_lse;
bool has_dropout;
bool do_fp8_static_quant; bool do_fp8_static_quant;
// TODO: padding check is inside this api // TODO: padding check is inside this api
}; };
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment