Commit b93575ca authored by Jing Zhang's avatar Jing Zhang
Browse files

merge develop

parents 54df59bf c8a8385f
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <sstream>
#include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/library/utility/host_tensor.hpp"
namespace ck {
namespace tensor_operation {
namespace host {
// dinput descriptor in [N, C, Do, Ho, Wo] order
// doutput descriptor in [N, C, Di, Hi, Wi] order
// phyiscal layout is irrelavent
template <ck::index_t NDimSpatial,
typename DInDataType,
typename DOutDataType,
typename std::enable_if<NDimSpatial >= 1 && NDimSpatial <= 3, bool>::type = false>
struct ReferenceAvgPoolBwd : public device::BaseOperator
{
// Argument
struct Argument : public device::BaseArgument
{
Argument(Tensor<DInDataType>& dinput,
const Tensor<DOutDataType>& doutput,
std::vector<ck::index_t> window_spatial_lengths,
std::vector<ck::index_t> window_strides,
std::vector<ck::index_t> window_dilations,
std::vector<ck::index_t> dinput_left_pads,
std::vector<ck::index_t> dinput_right_pads)
: dinput_{dinput},
doutput_{doutput},
window_spatial_lengths_{window_spatial_lengths},
window_strides_{window_strides},
window_dilations_{window_dilations},
in_left_pads_{dinput_left_pads},
in_right_pads_{dinput_right_pads}
{
}
Tensor<DInDataType>& dinput_;
const Tensor<DOutDataType>& doutput_;
std::vector<ck::index_t> window_spatial_lengths_;
std::vector<index_t> window_strides_;
std::vector<index_t> window_dilations_;
std::vector<index_t> in_left_pads_;
std::vector<index_t> in_right_pads_;
};
// Invoker
struct Invoker : public device::BaseInvoker
{
using Argument = ReferenceAvgPoolBwd::Argument;
template <ck::index_t NDimSpatial_,
typename std::enable_if<NDimSpatial_ == 1, bool>::type = false>
float RunAvgPoolBwd(const Argument& arg)
{
// Let input = x, outpu = y
// shape of x = [10], y = [6]
// window_size = 5, pad = 0, stride = 1, dilation = 1
// Forward:
// y0 = 1/5 * (x0 + x1 + x2 + x3 + x4)
// y1 = 1/5 * (x1 + x2 + x3 + x4 + x5)
// ...
// y5 = 1/5 * (x5 + x6 + x7 + x8 + x9)
// y6 = 1/5 * (x6 + x7 + x8 + x9)
// ...
// y9 = 1/5 * (x9)
// Backward:
// shape of dy = [6], dx = [10]
// dx0 = 1/5 * dy0
// dx1 = 1/5 * (dy0 + dy1)
// dx2 = 1/5 * (dy0 + dy1 + dy2)
// ...
// dx4 = 1/5 * (dy0 + dy1 + dy2 + dy3 + dy4)
// dx5 = 1/5 * (dy1 + dy2 + dy3 + dy4 + dy5)
// ...
// dx9 = 1/5 * (dy5 + dy6 + dy7 + dy8 + dy9)
auto f_ncw = [&](auto n, auto c, auto wi) {
std::size_t X = arg.window_spatial_lengths_[0];
std::size_t Wo = arg.doutput_.GetLengths()[2];
float v_acc = 0;
for(std::size_t x = 0; x < X; ++x)
{
// Out_Position = (In_Position + pad - x * dilation) / stride
auto w_tmp = static_cast<ck::long_index_t>(wi) +
static_cast<ck::long_index_t>(arg.in_left_pads_[0]) -
static_cast<ck::long_index_t>(x * arg.window_dilations_[0]);
// Check the input pixel validity (in perspective of being affected by some
// doutput pixel)
if(w_tmp % arg.window_strides_[0] == 0)
{
auto wo = static_cast<ck::long_index_t>(w_tmp) /
static_cast<ck::long_index_t>(arg.window_strides_[0]);
// Get the doutput pixel in valid range to accumulate the gradients for this
// input pixel
if(wo >= 0 && ck::type_convert<std::size_t>(wo) < Wo)
{
v_acc += ck::type_convert<float>(arg.doutput_(n, c, wo));
}
}
}
v_acc /= ck::type_convert<float>(X);
arg.dinput_(n, c, wi) = ck::type_convert<DInDataType>(v_acc);
};
make_ParallelTensorFunctor(f_ncw,
arg.dinput_.GetLengths()[0],
arg.dinput_.GetLengths()[1],
arg.dinput_.GetLengths()[2])(
std::thread::hardware_concurrency());
return 0;
}
template <ck::index_t NDimSpatial_,
typename std::enable_if<NDimSpatial_ == 2, bool>::type = false>
float RunAvgPoolBwd(const Argument& arg)
{
auto f_nchw = [&](auto n, auto c, auto hi, auto wi) {
std::size_t Y = arg.window_spatial_lengths_[0];
std::size_t X = arg.window_spatial_lengths_[1];
std::size_t Ho = arg.doutput_.GetLengths()[2];
std::size_t Wo = arg.doutput_.GetLengths()[3];
float v_acc = 0;
for(std::size_t y = 0; y < Y; ++y)
{
// Out_Position = (In_Position + pad - x * dilation) / stride
auto h_tmp = static_cast<ck::long_index_t>(hi) +
static_cast<ck::long_index_t>(arg.in_left_pads_[0]) -
static_cast<ck::long_index_t>(y * arg.window_dilations_[0]);
// Check the input pixel validity (in perspective of being affected by some
// doutput pixel)
if(h_tmp % arg.window_strides_[0] == 0)
{
auto ho = static_cast<ck::long_index_t>(h_tmp) /
static_cast<ck::long_index_t>(arg.window_strides_[0]);
// Get the doutput pixel in valid range to accumulate the gradients for this
// input pixel
if(ho >= 0 && ck::type_convert<std::size_t>(ho) < Ho)
{
for(std::size_t x = 0; x < X; ++x)
{
auto w_tmp =
static_cast<ck::long_index_t>(wi) +
static_cast<ck::long_index_t>(arg.in_left_pads_[1]) -
static_cast<ck::long_index_t>(x * arg.window_dilations_[1]);
if(w_tmp % arg.window_strides_[1] == 0)
{
auto wo = static_cast<ck::long_index_t>(w_tmp) /
static_cast<ck::long_index_t>(arg.window_strides_[1]);
if(wo >= 0 && ck::type_convert<std::size_t>(wo) < Wo)
{
v_acc +=
ck::type_convert<float>(arg.doutput_(n, c, ho, wo));
}
}
}
}
}
}
v_acc /= ck::type_convert<float>(Y * X);
arg.dinput_(n, c, hi, wi) = ck::type_convert<DInDataType>(v_acc);
};
make_ParallelTensorFunctor(f_nchw,
arg.dinput_.GetLengths()[0],
arg.dinput_.GetLengths()[1],
arg.dinput_.GetLengths()[2],
arg.dinput_.GetLengths()[3])(
std::thread::hardware_concurrency());
return 0;
}
template <ck::index_t NDimSpatial_,
typename std::enable_if<NDimSpatial_ == 3, bool>::type = false>
float RunAvgPoolBwd(const Argument& arg)
{
auto f_ncdhw = [&](auto n, auto c, auto di, auto hi, auto wi) {
std::size_t Z = arg.window_spatial_lengths_[0];
std::size_t Y = arg.window_spatial_lengths_[1];
std::size_t X = arg.window_spatial_lengths_[2];
std::size_t Do = arg.doutput_.GetLengths()[2];
std::size_t Ho = arg.doutput_.GetLengths()[3];
std::size_t Wo = arg.doutput_.GetLengths()[4];
float v_acc = 0;
for(std::size_t z = 0; z < Z; ++z)
{
// Out_Position = (In_Position + pad - x * dilation) / stride
auto d_tmp = static_cast<ck::long_index_t>(di) +
static_cast<ck::long_index_t>(arg.in_left_pads_[0]) -
static_cast<ck::long_index_t>(z * arg.window_dilations_[0]);
// Check the input pixel validity (in perspective of being affected by some
// doutput pixel)
if(d_tmp % arg.window_strides_[0] == 0)
{
auto do_ = static_cast<ck::long_index_t>(d_tmp) /
static_cast<ck::long_index_t>(arg.window_strides_[0]);
// Get the doutput pixel in valid range to accumulate the gradients for this
// input pixel
if(do_ >= 0 && ck::type_convert<std::size_t>(do_) < Do)
{
for(std::size_t y = 0; y < Y; ++y)
{
auto h_tmp =
static_cast<ck::long_index_t>(hi) +
static_cast<ck::long_index_t>(arg.in_left_pads_[1]) -
static_cast<ck::long_index_t>(y * arg.window_dilations_[1]);
if(h_tmp % arg.window_strides_[1] == 0)
{
auto ho = static_cast<ck::long_index_t>(h_tmp) /
static_cast<ck::long_index_t>(arg.window_strides_[1]);
if(ho >= 0 && ck::type_convert<std::size_t>(ho) < Ho)
{
for(std::size_t x = 0; x < X; ++x)
{
auto w_tmp = static_cast<ck::long_index_t>(wi) +
static_cast<ck::long_index_t>(
arg.in_left_pads_[2]) -
static_cast<ck::long_index_t>(
x * arg.window_dilations_[2]);
if(w_tmp % arg.window_strides_[2] == 0)
{
auto wo = static_cast<ck::long_index_t>(w_tmp) /
static_cast<ck::long_index_t>(
arg.window_strides_[2]);
if(wo >= 0 &&
ck::type_convert<std::size_t>(wo) < Wo)
{
v_acc += ck::type_convert<float>(
arg.doutput_(n, c, do_, ho, wo));
}
}
}
}
}
}
}
}
}
v_acc /= ck::type_convert<float>(Z * Y * X);
arg.dinput_(n, c, di, hi, wi) = ck::type_convert<DInDataType>(v_acc);
};
make_ParallelTensorFunctor(f_ncdhw,
arg.dinput_.GetLengths()[0],
arg.dinput_.GetLengths()[1],
arg.dinput_.GetLengths()[2],
arg.dinput_.GetLengths()[3],
arg.dinput_.GetLengths()[4])(
std::thread::hardware_concurrency());
return 0;
}
float Run(const Argument& arg)
{
if(!(arg.dinput_.GetNumOfDimension() == NDimSpatial + 2 &&
arg.doutput_.GetNumOfDimension() == NDimSpatial + 2))
{
throw std::runtime_error("wrong! inconsistent dimension");
}
return RunAvgPoolBwd<NDimSpatial>(arg);
}
float Run(const device::BaseArgument* p_arg,
const StreamConfig& /* stream_config */ = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg));
}
};
static constexpr bool IsValidCompilationParameter()
{
// TODO: properly implement this check
return true;
}
bool IsSupportedArgument(const device::BaseArgument*) override { return true; }
static auto MakeArgument(Tensor<DInDataType>& dinput,
const Tensor<DOutDataType>& doutput,
std::vector<ck::index_t> window_spatial_lengths,
std::vector<ck::index_t> window_strides,
std::vector<ck::index_t> window_dilations,
std::vector<ck::index_t> dinput_left_pads,
std::vector<ck::index_t> dinput_right_pads)
{
if(window_spatial_lengths.size() != NDimSpatial || window_strides.size() != NDimSpatial ||
window_dilations.size() != NDimSpatial || dinput_left_pads.size() != NDimSpatial ||
dinput_right_pads.size() != NDimSpatial)
throw std::runtime_error("dimension is incorrect");
return Argument{dinput,
doutput,
window_spatial_lengths,
window_strides,
window_dilations,
dinput_left_pads,
dinput_right_pads};
}
static auto MakeInvoker() { return Invoker{}; }
virtual std::unique_ptr<device::BaseInvoker> MakeInvokerPointer()
{
return std::make_unique<Invoker>(Invoker{});
}
std::string GetTypeString() const override
{
auto str = std::stringstream();
// clang-format off
str << "ReferenceAvgPoolBwd"
<< std::endl;
// clang-format on
return str.str();
}
};
} // namespace host
} // namespace tensor_operation
} // namespace ck
......@@ -125,7 +125,7 @@ struct ReferenceConvBwdData : public device::BaseOperator
arg.in_element_op_(v_in, v_acc);
arg.input_(g, n, c, wi) = ck::type_convert<InDataType>(v_acc);
arg.input_(g, n, c, wi) = ck::type_convert<InDataType>(v_in);
};
make_ParallelTensorFunctor(f_ncw,
......@@ -201,7 +201,7 @@ struct ReferenceConvBwdData : public device::BaseOperator
arg.in_element_op_(v_in, v_acc);
arg.input_(g, n, c, hi, wi) = ck::type_convert<InDataType>(v_acc);
arg.input_(g, n, c, hi, wi) = ck::type_convert<InDataType>(v_in);
};
make_ParallelTensorFunctor(f_nchw,
......@@ -299,7 +299,7 @@ struct ReferenceConvBwdData : public device::BaseOperator
arg.in_element_op_(v_in, v_acc);
arg.input_(g, n, c, di, hi, wi) = ck::type_convert<InDataType>(v_acc);
arg.input_(g, n, c, di, hi, wi) = ck::type_convert<InDataType>(v_in);
};
make_ParallelTensorFunctor(f_ncdhw,
......
......@@ -39,6 +39,7 @@ struct ReferencePoolingFwd : public device::BaseOperator
Tensor<IndexDataType>& out_indices,
const std::vector<ck::index_t>& window_spatial_lengths,
const std::vector<ck::index_t>& window_strides,
const std::vector<ck::index_t>& window_dilations,
const std::vector<ck::index_t>& in_left_pads,
const std::vector<ck::index_t>& /*in_right_pads*/)
: in_(in),
......@@ -46,6 +47,7 @@ struct ReferencePoolingFwd : public device::BaseOperator
out_indices_(out_indices),
window_spatial_lengths_(window_spatial_lengths),
window_strides_(window_strides),
window_dilations_(window_dilations),
in_left_pads_(in_left_pads),
reduceLength_(1)
{
......@@ -58,6 +60,7 @@ struct ReferencePoolingFwd : public device::BaseOperator
Tensor<IndexDataType>& out_indices_;
const std::vector<ck::index_t>& window_spatial_lengths_;
const std::vector<ck::index_t>& window_strides_;
const std::vector<ck::index_t>& window_dilations_;
const std::vector<ck::index_t>& in_left_pads_;
int reduceLength_;
};
......@@ -85,14 +88,17 @@ struct ReferencePoolingFwd : public device::BaseOperator
for(ck::index_t z = 0; z < arg.window_spatial_lengths_[0]; ++z)
{
ck::index_t di = do_ * arg.window_strides_[0] + z - arg.in_left_pads_[0];
ck::index_t di = do_ * arg.window_strides_[0] +
z * arg.window_dilations_[0] - arg.in_left_pads_[0];
for(ck::index_t y = 0; y < arg.window_spatial_lengths_[1]; ++y)
{
ck::index_t hi = ho * arg.window_strides_[1] + y - arg.in_left_pads_[1];
ck::index_t hi = ho * arg.window_strides_[1] +
y * arg.window_dilations_[1] - arg.in_left_pads_[1];
for(ck::index_t x = 0; x < arg.window_spatial_lengths_[2]; ++x)
{
ck::index_t wi =
wo * arg.window_strides_[2] + x - arg.in_left_pads_[2];
ck::index_t wi = wo * arg.window_strides_[2] +
x * arg.window_dilations_[2] -
arg.in_left_pads_[2];
if(di >= 0 &&
di < static_cast<ck::index_t>(arg.in_.mDesc.GetLengths()[2]) &&
hi >= 0 &&
......@@ -136,14 +142,17 @@ struct ReferencePoolingFwd : public device::BaseOperator
for(ck::index_t z = 0; z < arg.window_spatial_lengths_[0]; ++z)
{
ck::index_t di = do_ * arg.window_strides_[0] + z - arg.in_left_pads_[0];
ck::index_t di = do_ * arg.window_strides_[0] +
z * arg.window_dilations_[0] - arg.in_left_pads_[0];
for(ck::index_t y = 0; y < arg.window_spatial_lengths_[1]; ++y)
{
ck::index_t hi = ho * arg.window_strides_[1] + y - arg.in_left_pads_[1];
ck::index_t hi = ho * arg.window_strides_[1] +
y * arg.window_dilations_[1] - arg.in_left_pads_[1];
for(ck::index_t x = 0; x < arg.window_spatial_lengths_[2]; ++x)
{
ck::index_t wi =
wo * arg.window_strides_[2] + x - arg.in_left_pads_[2];
ck::index_t wi = wo * arg.window_strides_[2] +
x * arg.window_dilations_[2] -
arg.in_left_pads_[2];
if(di >= 0 &&
di < static_cast<ck::index_t>(arg.in_.mDesc.GetLengths()[2]) &&
hi >= 0 &&
......@@ -202,10 +211,12 @@ struct ReferencePoolingFwd : public device::BaseOperator
for(ck::index_t y = 0; y < arg.window_spatial_lengths_[0]; ++y)
{
ck::index_t hi = ho * arg.window_strides_[0] + y - arg.in_left_pads_[0];
ck::index_t hi = ho * arg.window_strides_[0] +
y * arg.window_dilations_[0] - arg.in_left_pads_[0];
for(ck::index_t x = 0; x < arg.window_spatial_lengths_[1]; ++x)
{
ck::index_t wi = wo * arg.window_strides_[1] + x - arg.in_left_pads_[1];
ck::index_t wi = wo * arg.window_strides_[1] +
x * arg.window_dilations_[1] - arg.in_left_pads_[1];
if(hi >= 0 &&
hi < static_cast<ck::index_t>(arg.in_.mDesc.GetLengths()[2]) &&
wi >= 0 &&
......@@ -308,6 +319,7 @@ struct ReferencePoolingFwd : public device::BaseOperator
Tensor<IndexDataType>& out_indices,
const std::vector<ck::index_t>& window_spatial_lengths,
const std::vector<ck::index_t>& window_strides,
const std::vector<ck::index_t>& window_dilations,
const std::vector<ck::index_t>& in_left_pads,
const std::vector<ck::index_t>& in_right_pads)
{
......@@ -316,6 +328,7 @@ struct ReferencePoolingFwd : public device::BaseOperator
out_indices,
window_spatial_lengths,
window_strides,
window_dilations,
in_left_pads,
in_right_pads};
}
......
......@@ -17,6 +17,7 @@ namespace instance {
using F64 = double;
using F32 = float;
using F16 = ck::half_t;
using F8 = ck::f8_t;
using BF16 = ck::bhalf_t;
using I8 = int8_t;
using I32 = int32_t;
......
......@@ -16,7 +16,7 @@ namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
#ifdef CK_ENABLE_BF16
void add_device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gkn_gmn_instances(
std::vector<std::unique_ptr<
DeviceBatchedGemm<Col, Row, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>&
......@@ -36,7 +36,8 @@ void add_device_batched_gemm_xdl_bf16_bf16_bf16_gmk_gnk_gmn_instances(
std::vector<std::unique_ptr<
DeviceBatchedGemm<Row, Col, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>&
instances);
#endif
#ifdef CK_ENABLE_FP16
void add_device_batched_gemm_xdl_f16_f16_f16_gkm_gkn_gmn_instances(
std::vector<std::unique_ptr<
DeviceBatchedGemm<Col, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
......@@ -56,7 +57,8 @@ void add_device_batched_gemm_xdl_f16_f16_f16_gmk_gnk_gmn_instances(
std::vector<std::unique_ptr<
DeviceBatchedGemm<Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances);
#endif
#ifdef CK_ENABLE_FP32
void add_device_batched_gemm_xdl_f32_f32_f32_gkm_gkn_gmn_instances(
std::vector<std::unique_ptr<
DeviceBatchedGemm<Col, Row, Row, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>&
......@@ -76,7 +78,8 @@ void add_device_batched_gemm_xdl_f32_f32_f32_gmk_gnk_gmn_instances(
std::vector<std::unique_ptr<
DeviceBatchedGemm<Row, Col, Row, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>&
instances);
#endif
#ifdef CK_ENABLE_INT8
void add_device_batched_gemm_xdl_int8_int8_int8_gkm_gkn_gmn_instances(
std::vector<std::unique_ptr<DeviceBatchedGemm<Col,
Row,
......@@ -120,7 +123,7 @@ void add_device_batched_gemm_xdl_int8_int8_int8_gmk_gnk_gmn_instances(
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
template <typename ALayout,
typename BLayout,
typename CLayout,
......@@ -151,7 +154,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceBatche
static auto GetInstances()
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
#ifdef CK_ENABLE_FP32
if constexpr(is_same_v<ADataType, float> && is_same_v<BDataType, float> &&
is_same_v<CDataType, float>)
{
......@@ -176,8 +179,10 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceBatche
add_device_batched_gemm_xdl_f32_f32_f32_gkm_gnk_gmn_instances(op_ptrs);
}
}
else if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, half_t> &&
is_same_v<CDataType, half_t>)
#endif
#ifdef CK_ENABLE_FP16
if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, half_t> &&
is_same_v<CDataType, half_t>)
{
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
is_same_v<CLayout, Row>)
......@@ -200,8 +205,10 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceBatche
add_device_batched_gemm_xdl_f16_f16_f16_gkm_gnk_gmn_instances(op_ptrs);
}
}
else if constexpr(is_same_v<ADataType, bhalf_t> && is_same_v<BDataType, bhalf_t> &&
is_same_v<CDataType, bhalf_t>)
#endif
#ifdef CK_ENABLE_BF16
if constexpr(is_same_v<ADataType, bhalf_t> && is_same_v<BDataType, bhalf_t> &&
is_same_v<CDataType, bhalf_t>)
{
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
is_same_v<CLayout, Row>)
......@@ -224,8 +231,10 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceBatche
add_device_batched_gemm_xdl_bf16_bf16_bf16_gkm_gnk_gmn_instances(op_ptrs);
}
}
else if constexpr(is_same_v<ADataType, int8_t> && is_same_v<BDataType, int8_t> &&
is_same_v<CDataType, int8_t>)
#endif
#ifdef CK_ENABLE_INT8
if constexpr(is_same_v<ADataType, int8_t> && is_same_v<BDataType, int8_t> &&
is_same_v<CDataType, int8_t>)
{
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
is_same_v<CLayout, Row>)
......@@ -248,7 +257,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceBatche
add_device_batched_gemm_xdl_int8_int8_int8_gkm_gnk_gmn_instances(op_ptrs);
}
}
#endif
return op_ptrs;
}
};
......
......@@ -14,7 +14,7 @@
using CDE0ElementOp = ck::tensor_operation::element_wise::AddRelu;
using CDE1ElementOp = ck::tensor_operation::element_wise::Add;
#ifdef CK_ENABLE_FP16
namespace ck {
namespace tensor_operation {
namespace device {
......@@ -137,3 +137,4 @@ struct DeviceOperationInstanceFactory<
} // namespace device
} // namespace tensor_operation
} // namespace ck
#endif
......@@ -13,7 +13,7 @@
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
#ifdef CK_ENABLE_FP16
namespace ck {
namespace tensor_operation {
namespace device {
......@@ -91,3 +91,4 @@ struct DeviceOperationInstanceFactory<
} // namespace device
} // namespace tensor_operation
} // namespace ck
#endif
......@@ -16,7 +16,7 @@ namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
#ifdef CK_ENABLE_FP16
void add_device_batched_gemm_bias_masking_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances(
std::vector<std::unique_ptr<
DeviceBatchedGemmSoftmaxGemmPermute<2,
......@@ -58,7 +58,8 @@ void add_device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_
PassThrough,
MaskingSpecialization::MaskDisabled>>>&
instances);
#endif
#ifdef CK_ENABLE_BF16
void add_device_batched_gemm_bias_masking_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instances(
std::vector<std::unique_ptr<
DeviceBatchedGemmSoftmaxGemmPermute<2,
......@@ -100,7 +101,7 @@ void add_device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf
PassThrough,
MaskingSpecialization::MaskDisabled>>>&
instances);
#endif
template <typename ADataType,
typename B0DataType,
typename B1DataType,
......@@ -147,7 +148,7 @@ struct DeviceOperationInstanceFactory<
static auto GetInstances()
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
#ifdef CK_ENABLE_FP16
if constexpr(is_same_v<ADataType, half_t> && is_same_v<B0DataType, half_t> &&
is_same_v<B1DataType, half_t> && is_same_v<CDataType, half_t> &&
Acc0BiasDataType::Size() == 1 &&
......@@ -164,6 +165,8 @@ struct DeviceOperationInstanceFactory<
op_ptrs);
}
}
#endif
#ifdef CK_ENABLE_BF16
else if constexpr(is_same_v<ADataType, BF16> && is_same_v<B0DataType, BF16> &&
is_same_v<B1DataType, BF16> && is_same_v<CDataType, BF16> &&
Acc0BiasDataType::Size() == 1 &&
......@@ -180,6 +183,7 @@ struct DeviceOperationInstanceFactory<
op_ptrs);
}
}
#endif
return op_ptrs;
}
};
......
......@@ -16,7 +16,7 @@ namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
#ifdef CK_ENABLE_FP16
void add_device_batched_gemm_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance(
std::vector<std::unique_ptr<DeviceBatchedGemmGemm<Row,
Col,
......@@ -111,3 +111,4 @@ struct DeviceOperationInstanceFactory<
} // namespace device
} // namespace tensor_operation
} // namespace ck
#endif
......@@ -19,7 +19,7 @@ namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
#ifdef CK_ENABLE_FP16
void add_device_batched_gemm_multi_d_dl_f16_f16_f16_gkm_gkn_gmn_instances(
std::vector<std::unique_ptr<DeviceBatchedGemmMultiD<Col,
Row,
......@@ -123,7 +123,8 @@ void add_device_batched_gemm_multi_d_dl_f16_f16_f16_gmk_gnk_gmn_irregular_instan
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_INT8
void add_device_batched_gemm_multi_d_dl_i8_i8_i8_gkm_gkn_gmn_instances(
std::vector<std::unique_ptr<DeviceBatchedGemmMultiD<Col,
Row,
......@@ -227,7 +228,7 @@ void add_device_batched_gemm_multi_d_dl_i8_i8_i8_gmk_gnk_gmn_irregular_instances
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
template <typename ALayout,
typename BLayout,
typename ELayout,
......@@ -262,7 +263,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceBatche
static auto GetInstances()
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
#ifdef CK_ENABLE_FP16
if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, half_t> &&
is_same_v<EDataType, half_t>)
{
......@@ -295,6 +296,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceBatche
op_ptrs);
}
}
#endif
#ifdef CK_ENABLE_INT8
else if constexpr(is_same_v<ADataType, int8_t> && is_same_v<BDataType, int8_t> &&
is_same_v<EDataType, int8_t>)
{
......@@ -327,7 +330,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceBatche
op_ptrs);
}
}
#endif
return op_ptrs;
}
};
......
......@@ -11,7 +11,7 @@
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
#ifdef CK_ENABLE_FP16
namespace ck {
namespace tensor_operation {
namespace device {
......@@ -119,3 +119,4 @@ struct DeviceOperationInstanceFactory<
} // namespace device
} // namespace tensor_operation
} // namespace ck
#endif
......@@ -16,7 +16,7 @@ namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
#ifdef CK_ENABLE_FP16
void add_device_batched_gemm_masking_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances(
std::vector<std::unique_ptr<
DeviceBatchedGemmSoftmaxGemmPermute<2,
......@@ -58,7 +58,8 @@ void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_g
PassThrough,
MaskingSpecialization::MaskDisabled>>>&
instances);
#endif
#ifdef CK_ENABLE_BF16
void add_device_batched_gemm_masking_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instances(
std::vector<std::unique_ptr<
DeviceBatchedGemmSoftmaxGemmPermute<2,
......@@ -100,6 +101,7 @@ void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf
PassThrough,
MaskingSpecialization::MaskDisabled>>>&
instances);
#endif
template <typename ADataType,
typename B0DataType,
......@@ -146,7 +148,7 @@ struct DeviceOperationInstanceFactory<
static auto GetInstances()
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
#ifdef CK_ENABLE_FP16
if constexpr(is_same_v<ADataType, half_t> && is_same_v<B0DataType, half_t> &&
is_same_v<B1DataType, half_t> && is_same_v<CDataType, half_t>)
{
......@@ -161,6 +163,8 @@ struct DeviceOperationInstanceFactory<
op_ptrs);
}
}
#endif
#ifdef CK_ENABLE_BF16
else if constexpr(is_same_v<ADataType, BF16> && is_same_v<B0DataType, BF16> &&
is_same_v<B1DataType, BF16> && is_same_v<CDataType, BF16>)
{
......@@ -175,6 +179,7 @@ struct DeviceOperationInstanceFactory<
op_ptrs);
}
}
#endif
return op_ptrs;
}
};
......
......@@ -16,7 +16,7 @@ namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
#ifdef CK_ENABLE_FP32
// float
void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_kknn_instance(
std::vector<std::unique_ptr<DeviceContractionMultipleD<2,
......@@ -65,7 +65,8 @@ void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_f32_mnnn
PassThrough,
PassThrough,
Bilinear>>>& instances);
#endif
#ifdef CK_ENABLE_FP64
// double
void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_kknn_instance(
std::vector<std::unique_ptr<DeviceContractionMultipleD<2,
......@@ -114,7 +115,7 @@ void add_device_contraction_bilinear_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_f64_mnnn
PassThrough,
PassThrough,
Bilinear>>>& instances);
#endif
// Contraction + Bilinear
template <index_t NumDimM,
index_t NumDimN,
......@@ -149,7 +150,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceContra
static auto GetInstances()
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
#ifdef CK_ENABLE_FP32
if constexpr(is_same_v<ADataType, float> && is_same_v<BDataType, float> &&
is_same_v<DDataType, float> && is_same_v<EDataType, float>)
{
......@@ -165,7 +166,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceContra
op_ptrs);
}
}
#endif
#ifdef CK_ENABLE_FP64
if constexpr(is_same_v<ADataType, double> && is_same_v<BDataType, double> &&
is_same_v<DDataType, double> && is_same_v<EDataType, double>)
{
......@@ -181,7 +183,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceContra
op_ptrs);
}
}
#endif
return op_ptrs;
}
};
......
......@@ -16,7 +16,7 @@ namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
#ifdef CK_ENABLE_FP32
// float
void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_kkn_instance(
std::vector<std::unique_ptr<DeviceContractionMultipleD<2,
......@@ -65,7 +65,8 @@ void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f32_f32_f32_mnn_instanc
PassThrough,
PassThrough,
Scale>>>& instances);
#endif
#ifdef CK_ENABLE_FP64
// double
void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_kkn_instance(
std::vector<std::unique_ptr<DeviceContractionMultipleD<2,
......@@ -114,7 +115,7 @@ void add_device_contraction_scale_m2_n2_k2_xdl_c_shuffle_f64_f64_f64_mnn_instanc
PassThrough,
PassThrough,
Scale>>>& instances);
#endif
// Contraction + Scale
template <index_t NumDimM,
index_t NumDimN,
......@@ -148,7 +149,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceContra
static auto GetInstances()
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
#ifdef CK_ENABLE_FP32
if constexpr(is_same_v<ADataType, float> && is_same_v<BDataType, float> &&
is_same_v<EDataType, float>)
{
......@@ -164,7 +165,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceContra
op_ptrs);
}
}
#endif
#ifdef CK_ENABLE_FP64
if constexpr(is_same_v<ADataType, double> && is_same_v<BDataType, double> &&
is_same_v<EDataType, double>)
{
......@@ -180,7 +182,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceContra
op_ptrs);
}
}
#endif
return op_ptrs;
}
};
......
......@@ -16,7 +16,7 @@ namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
#ifdef CK_ENABLE_BF16
// conv1d backward data
void add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_bf16_instances(
std::vector<std::unique_ptr<DeviceConvBwdData<1,
......@@ -29,17 +29,20 @@ void add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_bf16_instances(
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP16
void add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f16_instances(
std::vector<std::unique_ptr<
DeviceConvBwdData<1, NWC, KXC, NWK, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances);
#endif
#ifdef CK_ENABLE_FP32
void add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f32_instances(
std::vector<std::unique_ptr<
DeviceConvBwdData<1, NWC, KXC, NWK, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>&
instances);
#ifdef __int8__
#endif
#ifdef CK_ENABLE_INT8
void add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_int8_instances(
std::vector<std::unique_ptr<DeviceConvBwdData<1,
NWC,
......@@ -52,6 +55,7 @@ void add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_int8_instances(
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_BF16
// conv2d backward data
void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instances(
std::vector<std::unique_ptr<DeviceConvBwdData<2,
......@@ -64,7 +68,8 @@ void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instances(
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP16
void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f16_instances(
std::vector<std::unique_ptr<DeviceConvBwdData<2,
NHWC,
......@@ -76,7 +81,8 @@ void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f16_instances(
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP32
void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f32_instances(
std::vector<std::unique_ptr<DeviceConvBwdData<2,
NHWC,
......@@ -88,7 +94,8 @@ void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f32_instances(
PassThrough,
PassThrough,
PassThrough>>>& instances);
#ifdef __int8__
#endif
#ifdef CK_ENABLE_INT8
void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instances(
std::vector<std::unique_ptr<DeviceConvBwdData<2,
NHWC,
......@@ -101,6 +108,8 @@ void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instances(
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef DL_KERNELS
#ifdef CK_ENABLE_FP16
// conv2d dl
void add_device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_f16_instances(
std::vector<std::unique_ptr<DeviceConvBwdData<2,
......@@ -113,7 +122,8 @@ void add_device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_f16_instances(
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP32
void add_device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_f32_instances(
std::vector<std::unique_ptr<DeviceConvBwdData<2,
NHWC,
......@@ -125,7 +135,8 @@ void add_device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_f32_instances(
PassThrough,
PassThrough,
PassThrough>>>& instances);
#ifdef __int8__
#endif
#ifdef CK_ENABLE_INT8
void add_device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_int8_instances(
std::vector<std::unique_ptr<DeviceConvBwdData<2,
NHWC,
......@@ -138,6 +149,8 @@ void add_device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_int8_instances(
PassThrough,
PassThrough>>>& instances);
#endif
#endif
#ifdef CK_ENABLE_BF16
// conv3d backward data
void add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_bf16_instances(
std::vector<std::unique_ptr<DeviceConvBwdData<3,
......@@ -150,7 +163,8 @@ void add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_bf16_instances(
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP16
void add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f16_instances(
std::vector<std::unique_ptr<DeviceConvBwdData<3,
NDHWC,
......@@ -162,7 +176,8 @@ void add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f16_instances(
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_FP32
void add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f32_instances(
std::vector<std::unique_ptr<DeviceConvBwdData<3,
NDHWC,
......@@ -174,7 +189,8 @@ void add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f32_instances(
PassThrough,
PassThrough,
PassThrough>>>& instances);
#ifdef __int8__
#endif
#ifdef CK_ENABLE_INT8
void add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_int8_instances(
std::vector<std::unique_ptr<DeviceConvBwdData<3,
NDHWC,
......@@ -229,20 +245,23 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceConvBw
{
add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f32_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t>)
#ifdef CK_ENABLE_FP16
if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t>)
{
add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_f16_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
is_same_v<WeiDataType, ck::bhalf_t> &&
is_same_v<OutDataType, ck::bhalf_t>)
#endif
#ifdef CK_ENABLE_BF16
if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
is_same_v<WeiDataType, ck::bhalf_t> && is_same_v<OutDataType, ck::bhalf_t>)
{
add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_bf16_instances(op_ptrs);
}
#ifdef __int8__
else if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
is_same_v<OutDataType, int8_t>)
#endif
#ifdef CK_ENABLE_INT8
if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
is_same_v<OutDataType, int8_t>)
{
add_device_conv1d_bwd_data_xdl_nwc_kxc_nwk_int8_instances(op_ptrs);
}
......@@ -255,26 +274,35 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceConvBw
is_same_v<OutDataType, float>)
{
add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f32_instances(op_ptrs);
#ifdef DL_KERNELS
add_device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_f32_instances(op_ptrs);
#endif
}
else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t>)
#ifdef CK_ENABLE_FP16
if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t>)
{
add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f16_instances(op_ptrs);
#ifdef DL_KERNELS
add_device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_f16_instances(op_ptrs);
#endif
}
else if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
is_same_v<WeiDataType, ck::bhalf_t> &&
is_same_v<OutDataType, ck::bhalf_t>)
#endif
#ifdef CK_ENABLE_BF16
if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
is_same_v<WeiDataType, ck::bhalf_t> && is_same_v<OutDataType, ck::bhalf_t>)
{
add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instances(op_ptrs);
}
#ifdef __int8__
else if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
is_same_v<OutDataType, int8_t>)
#endif
#ifdef CK_ENABLE_INT8
if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
is_same_v<OutDataType, int8_t>)
{
add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instances(op_ptrs);
#ifdef DL_KERNELS
add_device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_int8_instances(op_ptrs);
#endif
}
#endif
}
......@@ -286,20 +314,23 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceConvBw
{
add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f32_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t>)
#ifdef CK_ENABLE_FP16
if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t>)
{
add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_f16_instances(op_ptrs);
}
else if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
is_same_v<WeiDataType, ck::bhalf_t> &&
is_same_v<OutDataType, ck::bhalf_t>)
#endif
#ifdef CK_ENABLE_BF16
if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
is_same_v<WeiDataType, ck::bhalf_t> && is_same_v<OutDataType, ck::bhalf_t>)
{
add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_bf16_instances(op_ptrs);
}
#ifdef __int8__
else if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
is_same_v<OutDataType, int8_t>)
#endif
#ifdef CK_ENABLE_INT8
if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
is_same_v<OutDataType, int8_t>)
{
add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_int8_instances(op_ptrs);
}
......
......@@ -18,11 +18,17 @@ namespace device {
namespace instance {
// conv2d forward
#ifdef CK_ENABLE_FP16
void add_device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instances(
std::vector<std::unique_ptr<
DeviceConvFwd<2, NHWC, KYXC, NHWK, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instances(
std::vector<std::unique_ptr<
DeviceConvFwd<2, NHWC, KYXC, NHWK, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances);
#endif
#ifdef CK_ENABLE_BF16
void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instances(
std::vector<std::unique_ptr<DeviceConvFwd<2,
NHWC,
......@@ -34,17 +40,14 @@ void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instances(
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instances(
std::vector<std::unique_ptr<
DeviceConvFwd<2, NHWC, KYXC, NHWK, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances);
#endif
#ifdef CK_ENABLE_FP32
void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances(
std::vector<std::unique_ptr<
DeviceConvFwd<2, NHWC, KYXC, NHWK, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>&
instances);
#endif
#ifdef CK_ENABLE_INT8
void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instances(
std::vector<std::unique_ptr<DeviceConvFwd<2,
NHWC,
......@@ -56,6 +59,7 @@ void add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instances(
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
template <ck::index_t NumDimSpatial,
typename InLayout,
......@@ -99,23 +103,29 @@ struct DeviceOperationInstanceFactory<
{
add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances(op_ptrs);
}
#ifdef CK_ENABLE_FP16
else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t>)
{
add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instances(op_ptrs);
add_device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instances(op_ptrs);
}
#endif
#ifdef CK_ENABLE_BF16
else if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
is_same_v<WeiDataType, ck::bhalf_t> &&
is_same_v<OutDataType, ck::bhalf_t>)
{
add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instances(op_ptrs);
}
#endif
#ifdef CK_ENABLE_INT8
else if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
is_same_v<OutDataType, int8_t>)
{
add_device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instances(op_ptrs);
}
#endif
}
return op_ptrs;
......
......@@ -11,7 +11,7 @@
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
#ifdef CK_ENABLE_FP16
namespace ck {
namespace tensor_operation {
namespace device {
......@@ -77,3 +77,4 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceElemen
} // namespace device
} // namespace tensor_operation
} // namespace ck
#endif
......@@ -17,12 +17,17 @@ namespace tensor_operation {
namespace device {
namespace instance {
#if defined(__fp16__) && defined(DL_KERNELS)
#if defined(CK_ENABLE_FP16) && defined(DL_KERNELS)
void add_device_gemm_dl_f16_f16_f16_km_kn_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Col, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_dl_dpp8_f16_f16_f16_km_kn_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Col, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_dl_f16_f16_f16_km_kn_mn_irregular_instances(
std::vector<std::unique_ptr<
DeviceGemm<Col, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
......@@ -33,6 +38,11 @@ void add_device_gemm_dl_f16_f16_f16_km_nk_mn_instances(
DeviceGemm<Col, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_dl_dpp8_f16_f16_f16_km_nk_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Col, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_dl_f16_f16_f16_km_nk_mn_irregular_instances(
std::vector<std::unique_ptr<
DeviceGemm<Col, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
......@@ -43,6 +53,11 @@ void add_device_gemm_dl_f16_f16_f16_mk_kn_mn_instances(
DeviceGemm<Row, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_dl_dpp8_f16_f16_f16_mk_kn_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Row, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_dl_f16_f16_f16_mk_kn_mn_irregular_instances(
std::vector<std::unique_ptr<
DeviceGemm<Row, Row, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
......@@ -53,12 +68,17 @@ void add_device_gemm_dl_f16_f16_f16_mk_nk_mn_instances(
DeviceGemm<Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_dl_dpp8_f16_f16_f16_mk_nk_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances);
void add_device_gemm_dl_f16_f16_f16_mk_nk_mn_irregular_instances(
std::vector<std::unique_ptr<
DeviceGemm<Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
instances);
#endif
#if defined(__fp32__) && defined(DL_KERNELS)
#if defined(CK_ENABLE_FP32) && defined(DL_KERNELS)
void add_device_gemm_dl_f32_f32_f32_km_kn_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Col, Row, Row, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>&
......@@ -79,7 +99,7 @@ void add_device_gemm_dl_f32_f32_f32_mk_nk_mn_instances(
DeviceGemm<Row, Col, Row, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>&
instances);
#endif
#if defined(__int8__) && defined(DL_KERNELS)
#if defined(CK_ENABLE_INT8) && defined(DL_KERNELS)
void add_device_gemm_dl_i8_i8_i8_km_kn_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Col, Row, Row, int8_t, int8_t, int8_t, PassThrough, PassThrough, PassThrough>>>&
......@@ -120,7 +140,7 @@ void add_device_gemm_dl_i8_i8_i8_mk_nk_mn_irregular_instances(
DeviceGemm<Row, Col, Row, int8_t, int8_t, int8_t, PassThrough, PassThrough, PassThrough>>>&
instances);
#endif
#ifdef __int8__
#ifdef CK_ENABLE_INT8
void add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Col, Row, Row, int8_t, int8_t, int8_t, PassThrough, PassThrough, PassThrough>>>&
......@@ -141,7 +161,7 @@ void add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instances(
DeviceGemm<Row, Col, Row, int8_t, int8_t, int8_t, PassThrough, PassThrough, PassThrough>>>&
instances);
#endif
#ifdef __fp16__
#ifdef CK_ENABLE_FP16
void add_device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Row, Col, Row, F16, F16, F16, PassThrough, PassThrough, PassThrough>>>&
......@@ -188,7 +208,7 @@ void add_device_gemm_xdl_f16_f16_f16_mk_nk_mn_instances(
instances);
#endif
#ifdef __bf16__
#ifdef CK_ENABLE_BF16
void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_kn_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Col, Row, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>&
......@@ -209,7 +229,7 @@ void add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_nk_mn_instances(
DeviceGemm<Row, Col, Row, BF16, BF16, BF16, PassThrough, PassThrough, PassThrough>>>&
instances);
#endif
#ifdef __fp32__
#ifdef CK_ENABLE_FP32
void add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_kn_mn_instances(
std::vector<std::unique_ptr<
DeviceGemm<Col, Row, Row, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>&
......@@ -250,7 +270,7 @@ void add_device_gemm_xdl_f32_f32_f32_mk_nk_mn_instances(
DeviceGemm<Row, Col, Row, F32, F32, F32, PassThrough, PassThrough, PassThrough>>>&
instances);
#endif
#ifdef __fp64__
#ifdef CK_ENABLE_FP64
void add_device_gemm_xdl_f64_f64_f64_km_kn_mn_instances(
std::vector<std::unique_ptr<
......@@ -343,6 +363,7 @@ struct DeviceOperationInstanceFactory<
add_device_gemm_xdl_c_shuffle_f32_f32_f32_km_nk_mn_instances(op_ptrs);
}
}
#ifdef CK_ENABLE_FP16
else if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, half_t> &&
is_same_v<CDataType, half_t>)
{
......@@ -353,6 +374,7 @@ struct DeviceOperationInstanceFactory<
#ifdef DL_KERNELS
add_device_gemm_dl_f16_f16_f16_mk_kn_mn_instances(op_ptrs);
add_device_gemm_dl_f16_f16_f16_mk_kn_mn_irregular_instances(op_ptrs);
add_device_gemm_dl_dpp8_f16_f16_f16_mk_kn_mn_instances(op_ptrs);
#endif
add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances(op_ptrs);
}
......@@ -363,6 +385,7 @@ struct DeviceOperationInstanceFactory<
#ifdef DL_KERNELS
add_device_gemm_dl_f16_f16_f16_mk_nk_mn_instances(op_ptrs);
add_device_gemm_dl_f16_f16_f16_mk_nk_mn_irregular_instances(op_ptrs);
add_device_gemm_dl_dpp8_f16_f16_f16_mk_nk_mn_instances(op_ptrs);
#endif
add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances(op_ptrs);
add_device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instances(op_ptrs);
......@@ -374,6 +397,7 @@ struct DeviceOperationInstanceFactory<
#ifdef DL_KERNELS
add_device_gemm_dl_f16_f16_f16_km_kn_mn_instances(op_ptrs);
add_device_gemm_dl_f16_f16_f16_km_kn_mn_irregular_instances(op_ptrs);
add_device_gemm_dl_dpp8_f16_f16_f16_km_kn_mn_instances(op_ptrs);
#endif
add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances(op_ptrs);
}
......@@ -384,10 +408,13 @@ struct DeviceOperationInstanceFactory<
#ifdef DL_KERNELS
add_device_gemm_dl_f16_f16_f16_km_nk_mn_instances(op_ptrs);
add_device_gemm_dl_f16_f16_f16_km_nk_mn_irregular_instances(op_ptrs);
add_device_gemm_dl_dpp8_f16_f16_f16_km_nk_mn_instances(op_ptrs);
#endif
add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances(op_ptrs);
}
}
#endif
#ifdef CK_ENABLE_BF16
else if constexpr(is_same_v<ADataType, ck::bhalf_t> && is_same_v<BDataType, ck::bhalf_t> &&
is_same_v<CDataType, ck::bhalf_t>)
{
......@@ -412,7 +439,8 @@ struct DeviceOperationInstanceFactory<
add_device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_nk_mn_instances(op_ptrs);
}
}
#ifdef __int8__
#endif
#ifdef CK_ENABLE_INT8
else if constexpr(is_same_v<ADataType, int8_t> && is_same_v<BDataType, int8_t> &&
is_same_v<CDataType, int8_t>)
{
......
......@@ -9,7 +9,7 @@
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d_layernorm.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
#ifdef CK_ENABLE_FP16
namespace ck {
namespace tensor_operation {
namespace device {
......@@ -170,3 +170,4 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu
} // namespace device
} // namespace tensor_operation
} // namespace ck
#endif
......@@ -11,7 +11,7 @@
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
#ifdef CK_ENABLE_FP16
namespace ck {
namespace tensor_operation {
namespace device {
......@@ -144,3 +144,4 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu
} // namespace device
} // namespace tensor_operation
} // namespace ck
#endif
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment