"vscode:/vscode.git/clone" did not exist on "fbc89ca7078386d3a73afd390fdf20c7cec0686e"
Commit ac876c6f authored by Chao Liu's avatar Chao Liu
Browse files

gemm bilinear

parent 8ec065fd
add_example_executable(example_gemm_alpha_beta_xdl_fp16 gemm_alpha_beta_xdl_fp16.cpp)
add_example_executable(example_gemm_bilinear_xdl_fp16 gemm_bilinear_xdl_fp16.cpp)
# Instructions for ```example_gemm_xdl_alpha_beta``` # Instructions for ```example_gemm_bilinear_xdl_fp16```
## Run ```example_gemm_xdl_alpha_beta``` ## Run ```example_gemm_bilinear_xdl_fp16```
```bash ```bash
#arg1: verification (0=no, 1=yes) #arg1: verification (0=no, 1=yes)
#arg2: initialization (0=no init, 1=integer value, 2=decimal value) #arg2: initialization (0=no init, 1=integer value, 2=decimal value)
#arg3: run kernel # of times (>1) #arg3: time kernel (0=no, 1=yes)
./bin/example_gemm_xdl_alpha_beta 1 1 1 0.5 0.5 #arg4 to 10: M (256x), N(128x), K(32x), StrideA, StrideB, StrideD, StrideE
#arg11 to 12: alpha, beta
./bin/example_gemm_bilinear_xdl_fp16 1 1 1 3840 4096 4096 4096 4096 4096 4096 0.5 0.5
``` ```
Result (MI100 @ 1502Mhz, 184.6TFlops peak FP16) Result (MI100 @ 1502Mhz, 184.6TFlops peak FP16)
``` ```
......
...@@ -22,7 +22,7 @@ function(add_example_executable_no_testing EXAMPLE_NAME FILE_NAME) ...@@ -22,7 +22,7 @@ function(add_example_executable_no_testing EXAMPLE_NAME FILE_NAME)
endfunction(add_example_executable_no_testing EXAMPLE_NAME) endfunction(add_example_executable_no_testing EXAMPLE_NAME)
add_subdirectory(01_gemm) add_subdirectory(01_gemm)
add_subdirectory(02_gemm_alpha_beta) add_subdirectory(02_gemm_bilinear)
add_subdirectory(03_gemm_bias_relu) add_subdirectory(03_gemm_bias_relu)
add_subdirectory(04_gemm_add_add_fastgelu) add_subdirectory(04_gemm_add_add_fastgelu)
add_subdirectory(06_conv2d_fwd_bias_relu) add_subdirectory(06_conv2d_fwd_bias_relu)
......
...@@ -23,7 +23,8 @@ template <typename ALayout, ...@@ -23,7 +23,8 @@ template <typename ALayout,
typename CElementwiseOperation> typename CElementwiseOperation>
struct DeviceBatchedGemm : public BaseOperator struct DeviceBatchedGemm : public BaseOperator
{ {
virtual std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a, virtual std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_a,
const void* p_b, const void* p_b,
void* p_c, void* p_c,
ck::index_t M, ck::index_t M,
...@@ -35,10 +36,10 @@ struct DeviceBatchedGemm : public BaseOperator ...@@ -35,10 +36,10 @@ struct DeviceBatchedGemm : public BaseOperator
ck::index_t BatchStrideA, ck::index_t BatchStrideA,
ck::index_t BatchStrideB, ck::index_t BatchStrideB,
ck::index_t BatchStrideC, ck::index_t BatchStrideC,
ck::index_t Batch,
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op, CElementwiseOperation c_element_op) = 0;
ck::index_t Batch) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0; virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
}; };
......
...@@ -344,12 +344,12 @@ struct DeviceBatchedGemmXdl : public DeviceBatchedGemm<ALayout, ...@@ -344,12 +344,12 @@ struct DeviceBatchedGemmXdl : public DeviceBatchedGemm<ALayout,
index_t BatchStrideA, index_t BatchStrideA,
index_t BatchStrideB, index_t BatchStrideB,
index_t BatchStrideC, index_t BatchStrideC,
index_t Batch,
index_t M01, index_t M01,
index_t N01, index_t N01,
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op, CElementwiseOperation c_element_op)
index_t Batch)
: p_a_grid_{p_a_grid}, : p_a_grid_{p_a_grid},
p_b_grid_{p_b_grid}, p_b_grid_{p_b_grid},
p_c_grid_{p_c_grid}, p_c_grid_{p_c_grid},
...@@ -546,10 +546,10 @@ struct DeviceBatchedGemmXdl : public DeviceBatchedGemm<ALayout, ...@@ -546,10 +546,10 @@ struct DeviceBatchedGemmXdl : public DeviceBatchedGemm<ALayout,
index_t BatchStrideA, index_t BatchStrideA,
index_t BatchStrideB, index_t BatchStrideB,
index_t BatchStrideC, index_t BatchStrideC,
index_t Batch,
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op, CElementwiseOperation c_element_op)
index_t Batch)
{ {
return Argument{p_a, return Argument{p_a,
p_b, p_b,
...@@ -563,12 +563,12 @@ struct DeviceBatchedGemmXdl : public DeviceBatchedGemm<ALayout, ...@@ -563,12 +563,12 @@ struct DeviceBatchedGemmXdl : public DeviceBatchedGemm<ALayout,
BatchStrideA, BatchStrideA,
BatchStrideB, BatchStrideB,
BatchStrideC, BatchStrideC,
Batch,
1, 1,
1, 1,
a_element_op, a_element_op,
b_element_op, b_element_op,
c_element_op, c_element_op};
Batch};
} }
static auto MakeInvoker() { return Invoker{}; } static auto MakeInvoker() { return Invoker{}; }
...@@ -586,10 +586,10 @@ struct DeviceBatchedGemmXdl : public DeviceBatchedGemm<ALayout, ...@@ -586,10 +586,10 @@ struct DeviceBatchedGemmXdl : public DeviceBatchedGemm<ALayout,
index_t BatchStrideA, index_t BatchStrideA,
index_t BatchStrideB, index_t BatchStrideB,
index_t BatchStrideC, index_t BatchStrideC,
index_t Batch,
AElementwiseOperation a_element_op, AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op, BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op, CElementwiseOperation c_element_op) override
index_t Batch) override
{ {
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a), return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
static_cast<const BDataType*>(p_b), static_cast<const BDataType*>(p_b),
...@@ -603,12 +603,12 @@ struct DeviceBatchedGemmXdl : public DeviceBatchedGemm<ALayout, ...@@ -603,12 +603,12 @@ struct DeviceBatchedGemmXdl : public DeviceBatchedGemm<ALayout,
BatchStrideA, BatchStrideA,
BatchStrideB, BatchStrideB,
BatchStrideC, BatchStrideC,
Batch,
1, 1,
1, 1,
a_element_op, a_element_op,
b_element_op, b_element_op,
c_element_op, c_element_op);
Batch);
} }
// polymorphic // polymorphic
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include "ck/tensor_operation/gpu/device/device_base.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
template <typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation>
struct DeviceGemmBias : public BaseOperator
{
virtual std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_a,
const void* p_b,
const void* p_bias,
void* p_c,
ck::index_t M,
ck::index_t N,
ck::index_t K,
ck::index_t StrideA,
ck::index_t StrideB,
ck::index_t StrideC,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};
template <typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation>
using DeviceGemmBiasPtr = std::unique_ptr<
DeviceGemmBias<AElementwiseOperation, BElementwiseOperation, CElementwiseOperation>>;
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include "ck/tensor_operation/gpu/device/device_base.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
template <typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation>
struct DeviceGemmBiasActivation : public BaseOperator
{
virtual std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
const void* p_b,
void* p_c,
const void* p_c0,
ck::index_t M,
ck::index_t N,
ck::index_t K,
ck::index_t StrideA,
ck::index_t StrideB,
ck::index_t StrideC,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op,
ck::index_t KBatch = 1) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};
template <typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation>
using DeviceGemmBiasActivationPtr = std::unique_ptr<
DeviceGemmBiasActivation<AElementwiseOperation, BElementwiseOperation, CElementwiseOperation>>;
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#ifndef DEVICE_GEMM_BIAS_ACTIVATION_ADD_HPP
#define DEVICE_GEMM_BIAS_ACTIVATION_ADD_HPP
#include <iostream>
#include "device_base.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
template <typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation>
struct DeviceGemmBiasActivationAdd : public BaseOperator
{
virtual std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
const void* p_b,
void* p_c,
const void* p_c0,
const void* p_c1,
ck::index_t M,
ck::index_t N,
ck::index_t K,
ck::index_t StrideA,
ck::index_t StrideB,
ck::index_t StrideC,
ck::index_t StrideC1,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op,
ck::index_t KBatch = 1) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};
template <typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation>
using DeviceGemmBiasActivationAddPtr =
std::unique_ptr<DeviceGemmBiasActivationAdd<AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation>>;
} // namespace device
} // namespace tensor_operation
} // namespace ck
#endif
...@@ -94,10 +94,11 @@ struct Subtract ...@@ -94,10 +94,11 @@ struct Subtract
} }
}; };
struct AlphaBetaAdd struct Bilinear
{ {
AlphaBetaAdd(float alpha, float beta) : alpha_(alpha), beta_(beta){}; Bilinear(float alpha, float beta) : alpha_(alpha), beta_(beta){};
#if 0
template <typename T> template <typename T>
__host__ __device__ constexpr void operator()(T& y, const T& x0, const T& x1) const; __host__ __device__ constexpr void operator()(T& y, const T& x0, const T& x1) const;
...@@ -115,7 +116,6 @@ struct AlphaBetaAdd ...@@ -115,7 +116,6 @@ struct AlphaBetaAdd
y = type_convert<double>(alpha_) * x0 + type_convert<double>(beta_) * x1; y = type_convert<double>(alpha_) * x0 + type_convert<double>(beta_) * x1;
}; };
// Question: should half_t be supported ?
template <> template <>
__host__ __device__ constexpr void __host__ __device__ constexpr void
operator()<half_t>(half_t& y, const half_t& x0, const half_t& x1) const operator()<half_t>(half_t& y, const half_t& x0, const half_t& x1) const
...@@ -123,6 +123,24 @@ struct AlphaBetaAdd ...@@ -123,6 +123,24 @@ struct AlphaBetaAdd
y = type_convert<half_t>(alpha_ * type_convert<float>(x0) + y = type_convert<half_t>(alpha_ * type_convert<float>(x0) +
beta_ * type_convert<float>(x1)); beta_ * type_convert<float>(x1));
}; };
#else
template <typename Y, typename X0, typename X1>
__host__ __device__ constexpr void operator()(Y&, const X0&, const X1&) const;
template <>
__host__ __device__ constexpr void
operator()<float, float, float>(float& y, const float& x0, const float& x1) const
{
y = alpha_ * x0 + beta_ * x1;
};
template <>
__host__ __device__ constexpr void
operator()<half_t, float, half_t>(half_t& y, const float& x0, const half_t& x1) const
{
y = type_convert<half_t>(alpha_ * x0 + beta_ * ck::type_convert<float>(x1));
};
#endif
float alpha_; float alpha_;
float beta_; float beta_;
......
...@@ -16,12 +16,14 @@ using F32 = float; ...@@ -16,12 +16,14 @@ using F32 = float;
using F16 = ck::half_t; using F16 = ck::half_t;
using BF16 = ck::bhalf_t; using BF16 = ck::bhalf_t;
using F16_F16 = ck::Tuple<F16, F16>; using F16_TUPLE = ck::Tuple<F16>;
using F16_F16_TUPLE = ck::Tuple<F16, F16>;
using Row = ck::tensor_layout::gemm::RowMajor; using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor; using Col = ck::tensor_layout::gemm::ColumnMajor;
using PassThrough = ck::tensor_operation::element_wise::PassThrough; using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using Bilinear = ck::tensor_operation::element_wise::Bilinear;
using AddAddFastGelu = ck::tensor_operation::element_wise::AddAddFastGelu; using AddAddFastGelu = ck::tensor_operation::element_wise::AddAddFastGelu;
template <typename DeviceOp> template <typename DeviceOp>
......
...@@ -25,7 +25,7 @@ void add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instanc ...@@ -25,7 +25,7 @@ void add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instanc
Row, Row,
F16, F16,
F16, F16,
F16_F16, F16_F16_TUPLE,
F16, F16,
PassThrough, PassThrough,
PassThrough, PassThrough,
...@@ -37,7 +37,7 @@ void add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instanc ...@@ -37,7 +37,7 @@ void add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instanc
Row, Row,
F16, F16,
F16, F16,
F16_F16, F16_F16_TUPLE,
F16, F16,
PassThrough, PassThrough,
PassThrough, PassThrough,
...@@ -49,7 +49,7 @@ void add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instanc ...@@ -49,7 +49,7 @@ void add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instanc
Row, Row,
F16, F16,
F16, F16,
F16_F16, F16_F16_TUPLE,
F16, F16,
PassThrough, PassThrough,
PassThrough, PassThrough,
...@@ -61,7 +61,7 @@ void add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instanc ...@@ -61,7 +61,7 @@ void add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instanc
Row, Row,
F16, F16,
F16, F16,
F16_F16, F16_F16_TUPLE,
F16, F16,
PassThrough, PassThrough,
PassThrough, PassThrough,
...@@ -73,7 +73,8 @@ template <typename ALayout, ...@@ -73,7 +73,8 @@ template <typename ALayout,
typename DELayout, typename DELayout,
typename ADataType, typename ADataType,
typename BDataType, typename BDataType,
typename DsDataType, typename D0DataType,
typename D1DataType,
typename EDataType> typename EDataType>
struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMultipleD< struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMultipleD<
ALayout, ALayout,
...@@ -81,7 +82,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu ...@@ -81,7 +82,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu
DELayout, DELayout,
ADataType, ADataType,
BDataType, BDataType,
DsDataType, ck::Tuple<D0DataType, D1DataType>,
EDataType, EDataType,
ck::tensor_operation::element_wise::PassThrough, ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough, ck::tensor_operation::element_wise::PassThrough,
...@@ -92,7 +93,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu ...@@ -92,7 +93,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu
DELayout, DELayout,
ADataType, ADataType,
BDataType, BDataType,
DsDataType, ck::Tuple<D0DataType, D1DataType>,
EDataType, EDataType,
ck::tensor_operation::element_wise::PassThrough, ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough, ck::tensor_operation::element_wise::PassThrough,
...@@ -103,7 +104,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu ...@@ -103,7 +104,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu
std::vector<std::unique_ptr<DeviceOp>> op_ptrs; std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, half_t> && if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, half_t> &&
is_same_v<DsDataType, Tuple<half_t, half_t>> && is_same_v<EDataType, half_t>) is_same_v<D0DataType, half_t> && is_same_v<D1DataType, half_t> &&
is_same_v<EDataType, half_t>)
{ {
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> && if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
is_same_v<DELayout, Row>) is_same_v<DELayout, Row>)
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
#include <vector>
#include <memory>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleD<Col,
Row,
Row,
F16,
F16,
F16_TUPLE,
F16,
PassThrough,
PassThrough,
Bilinear>>>& instances);
// GEMM + Bilinear
template <typename ALayout,
typename BLayout,
typename DELayout,
typename ADataType,
typename BDataType,
typename DDataType,
typename EDataType>
struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMultipleD<
ALayout,
BLayout,
DELayout,
ADataType,
BDataType,
ck::Tuple<DDataType>,
EDataType,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::Bilinear>>
{
using DeviceOp = DeviceGemmMultipleD<ALayout,
BLayout,
DELayout,
ADataType,
BDataType,
ck::Tuple<DDataType>,
EDataType,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::Bilinear>;
static auto GetInstances()
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, half_t> &&
is_same_v<DDataType, half_t> && is_same_v<EDataType, half_t>)
{
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
is_same_v<DELayout, Row>)
{}
else if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
is_same_v<DELayout, Row>)
{
}
else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Row> &&
is_same_v<DELayout, Row>)
{
add_device_gemm_bilinear_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances(op_ptrs);
}
else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Col> &&
is_same_v<DELayout, Row>)
{
}
}
return op_ptrs;
}
};
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
...@@ -5,15 +5,12 @@ function(add_instance_library INSTANCE_NAME) ...@@ -5,15 +5,12 @@ function(add_instance_library INSTANCE_NAME)
set_target_properties(${INSTANCE_NAME} PROPERTIES POSITION_INDEPENDENT_CODE ON) set_target_properties(${INSTANCE_NAME} PROPERTIES POSITION_INDEPENDENT_CODE ON)
endfunction(add_instance_library INSTANCE_NAME) endfunction(add_instance_library INSTANCE_NAME)
add_subdirectory(elementwise)
add_subdirectory(gemm) add_subdirectory(gemm)
add_subdirectory(gemm_splitk) add_subdirectory(gemm_splitk)
add_subdirectory(gemm_bias2d) add_subdirectory(gemm_bilinear)
add_subdirectory(gemm_bias_relu) add_subdirectory(gemm_add_add_fastgelu)
add_subdirectory(gemm_bias_relu_add)
add_subdirectory(gemm_reduce) add_subdirectory(gemm_reduce)
add_subdirectory(gemm_bias_add_reduce) add_subdirectory(gemm_bias_add_reduce)
add_subdirectory(gemm_add_add_fastgelu)
add_subdirectory(batched_gemm) add_subdirectory(batched_gemm)
add_subdirectory(batched_gemm_reduce) add_subdirectory(batched_gemm_reduce)
add_subdirectory(grouped_gemm) add_subdirectory(grouped_gemm)
...@@ -25,17 +22,16 @@ add_subdirectory(conv2d_fwd_bias_relu_add) ...@@ -25,17 +22,16 @@ add_subdirectory(conv2d_fwd_bias_relu_add)
add_subdirectory(conv2d_bwd_data) add_subdirectory(conv2d_bwd_data)
add_subdirectory(convnd_bwd_data) add_subdirectory(convnd_bwd_data)
add_subdirectory(conv2d_bwd_weight) add_subdirectory(conv2d_bwd_weight)
add_subdirectory(normalization)
add_subdirectory(reduce) add_subdirectory(reduce)
add_subdirectory(normalization)
add_subdirectory(elementwise)
add_library(device_operations STATIC add_library(device_operations STATIC
$<TARGET_OBJECTS:device_gemm_instance> $<TARGET_OBJECTS:device_gemm_instance>
$<TARGET_OBJECTS:device_gemm_splitk_instance> $<TARGET_OBJECTS:device_gemm_splitk_instance>
$<TARGET_OBJECTS:device_gemm_bias_relu_instance> $<TARGET_OBJECTS:device_gemm_bilinear_instance>
$<TARGET_OBJECTS:device_gemm_bias_relu_add_instance>
$<TARGET_OBJECTS:device_gemm_bias_add_reduce_instance>
$<TARGET_OBJECTS:device_gemm_bias2d_instance>
$<TARGET_OBJECTS:device_gemm_add_add_fastgelu_instance> $<TARGET_OBJECTS:device_gemm_add_add_fastgelu_instance>
$<TARGET_OBJECTS:device_gemm_bias_add_reduce_instance>
$<TARGET_OBJECTS:device_batched_gemm_instance> $<TARGET_OBJECTS:device_batched_gemm_instance>
$<TARGET_OBJECTS:device_batched_gemm_reduce_instance> $<TARGET_OBJECTS:device_batched_gemm_reduce_instance>
$<TARGET_OBJECTS:device_grouped_gemm_instance> $<TARGET_OBJECTS:device_grouped_gemm_instance>
...@@ -47,9 +43,9 @@ add_library(device_operations STATIC ...@@ -47,9 +43,9 @@ add_library(device_operations STATIC
$<TARGET_OBJECTS:device_conv2d_bwd_data_instance> $<TARGET_OBJECTS:device_conv2d_bwd_data_instance>
$<TARGET_OBJECTS:device_convnd_bwd_data_instance> $<TARGET_OBJECTS:device_convnd_bwd_data_instance>
$<TARGET_OBJECTS:device_conv2d_bwd_weight_instance> $<TARGET_OBJECTS:device_conv2d_bwd_weight_instance>
$<TARGET_OBJECTS:device_elementwise_instance>
$<TARGET_OBJECTS:device_gemm_add_add_fastgelu_instance>
$<TARGET_OBJECTS:device_reduce_instance> $<TARGET_OBJECTS:device_reduce_instance>
$<TARGET_OBJECTS:device_normalization_instance>
$<TARGET_OBJECTS:device_elementwise_instance>
) )
add_library(composablekernels::device_operations ALIAS device_operations) add_library(composablekernels::device_operations ALIAS device_operations)
...@@ -81,7 +77,6 @@ target_include_directories(device_operations PUBLIC ...@@ -81,7 +77,6 @@ target_include_directories(device_operations PUBLIC
#once new arches are enabled make this an option on the main cmake file #once new arches are enabled make this an option on the main cmake file
# and pass down here to be exported # and pass down here to be exported
target_compile_options(device_operations PRIVATE target_compile_options(device_operations PRIVATE
--offload-arch=gfx908 --offload-arch=gfx908
--offload-arch=gfx90a --offload-arch=gfx90a
......
...@@ -16,7 +16,7 @@ namespace instance { ...@@ -16,7 +16,7 @@ namespace instance {
using F16 = ck::half_t; using F16 = ck::half_t;
using F32 = float; using F32 = float;
using F16_F16 = ck::Tuple<F16, F16>; using F16_F16_Tuple = ck::Tuple<F16, F16>;
using Row = ck::tensor_layout::gemm::RowMajor; using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor; using Col = ck::tensor_layout::gemm::ColumnMajor;
...@@ -38,22 +38,22 @@ using device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances ...@@ -38,22 +38,22 @@ using device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances
//##############################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| //##############################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//##############################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //##############################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row, F16, F16, F32, F32, F16_F16, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 256, 128, 32, 2, 2, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 256, 128, 32, 2, 2, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row, F16, F16, F32, F32, F16_F16, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row, F16, F16, F32, F32, F16_F16, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 128, 256, 32, 2, 2, 32, 32, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 128, 256, 32, 2, 2, 32, 32, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row, F16, F16, F32, F32, F16_F16, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row, F16, F16, F32, F32, F16_F16, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 128, 128, 128, 32, 2, 2, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8>, DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 128, 128, 128, 32, 2, 2, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8>,
DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row, F16, F16, F32, F32, F16_F16, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>,
DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row, F16, F16, F32, F32, F16_F16, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 128, 128, 32, 2, 2, 32, 32, 2, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 128, 128, 32, 2, 2, 32, 32, 2, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row, F16, F16, F32, F32, F16_F16, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row, F16, F16, F32, F32, F16_F16, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 128, 128, 64, 32, 2, 2, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, 8>, DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 128, 128, 64, 32, 2, 2, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, 8>,
DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row, F16, F16, F32, F32, F16_F16, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row, F16, F16, F32, F32, F16_F16, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 128, 64, 128, 32, 2, 2, 32, 32, 2, 2, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8>, DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 128, 64, 128, 32, 2, 2, 32, 32, 2, 2, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8>,
DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row, F16, F16, F32, F32, F16_F16, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>,
DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row, F16, F16, F32, F32, F16_F16, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 128, 64, 32, 2, 2, 32, 32, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 128, 64, 32, 2, 2, 32, 32, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row, F16, F16, F32, F32, F16_F16, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row, F16, F16, F32, F32, F16_F16, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 64, 128, 32, 2, 2, 32, 32, 1, 2, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 64, 128, 32, 2, 2, 32, 32, 1, 2, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row, F16, F16, F32, F32, F16_F16, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8> DeviceGemmMultipleD_Xdl_CShuffle< Col, Row, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>
// clang-format on // clang-format on
>; >;
...@@ -63,7 +63,7 @@ void add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instanc ...@@ -63,7 +63,7 @@ void add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instanc
Row, Row,
F16, F16,
F16, F16,
F16_F16, F16_F16_Tuple,
F16, F16,
PassThrough, PassThrough,
PassThrough, PassThrough,
......
...@@ -16,7 +16,7 @@ namespace instance { ...@@ -16,7 +16,7 @@ namespace instance {
using F16 = ck::half_t; using F16 = ck::half_t;
using F32 = float; using F32 = float;
using F16_F16 = ck::Tuple<F16, F16>; using F16_F16_Tuple = ck::Tuple<F16, F16>;
using Row = ck::tensor_layout::gemm::RowMajor; using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor; using Col = ck::tensor_layout::gemm::ColumnMajor;
...@@ -38,22 +38,22 @@ using device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances ...@@ -38,22 +38,22 @@ using device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances
//##############################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| //##############################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//##############################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //##############################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row, F16, F16, F32, F32, F16_F16, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 256, 128, 32, 2, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 256, 128, 32, 2, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row, F16, F16, F32, F32, F16_F16, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row, F16, F16, F32, F32, F16_F16, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 128, 256, 32, 2, 8, 32, 32, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 128, 256, 32, 2, 8, 32, 32, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row, F16, F16, F32, F32, F16_F16, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row, F16, F16, F32, F32, F16_F16, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 128, 128, 128, 32, 2, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 128, 128, 128, 32, 2, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>,
DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row, F16, F16, F32, F32, F16_F16, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>,
DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row, F16, F16, F32, F32, F16_F16, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 128, 128, 32, 2, 8, 32, 32, 2, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 128, 128, 32, 2, 8, 32, 32, 2, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row, F16, F16, F32, F32, F16_F16, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row, F16, F16, F32, F32, F16_F16, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 128, 128, 64, 32, 2, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 128, 128, 64, 32, 2, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row, F16, F16, F32, F32, F16_F16, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row, F16, F16, F32, F32, F16_F16, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 128, 64, 128, 32, 2, 8, 32, 32, 2, 2, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 128, 64, 128, 32, 2, 8, 32, 32, 2, 2, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>,
DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row, F16, F16, F32, F32, F16_F16, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>,
DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row, F16, F16, F32, F32, F16_F16, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 128, 64, 32, 2, 8, 32, 32, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 128, 64, 32, 2, 8, 32, 32, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row, F16, F16, F32, F32, F16_F16, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row, F16, F16, F32, F32, F16_F16, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 64, 128, 32, 2, 8, 32, 32, 1, 2, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 64, 128, 32, 2, 8, 32, 32, 1, 2, S<16,16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row, F16, F16, F32, F32, F16_F16, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8> DeviceGemmMultipleD_Xdl_CShuffle< Col, Col, Row, F16, F16, F32, F32, F16_F16_Tuple, F16, PassThrough, PassThrough, AddAddFastGelu, GemmDefault, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>
// clang-format on // clang-format on
>; >;
...@@ -63,7 +63,7 @@ void add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instanc ...@@ -63,7 +63,7 @@ void add_device_gemm_add_add_fastgelu_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instanc
Row, Row,
F16, F16,
F16, F16,
F16_F16, F16_F16_Tuple,
F16, F16,
PassThrough, PassThrough,
PassThrough, PassThrough,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment