Commit cc6a534f authored by aska-0096's avatar aska-0096
Browse files

Merge branch 'develop' of...

Merge branch 'develop' of https://github.com/ROCmSoftwarePlatform/composable_kernel into navi3x_md_bgemm_conv_gemmsoftmaxgemm
parents 27dc055b cb3fac4d
...@@ -47,3 +47,7 @@ build* ...@@ -47,3 +47,7 @@ build*
# GDB temporary files # GDB temporary files
.gdb_history .gdb_history
install.dir* install.dir*
# directories containing generated documentation
docs/source/_build/
docs/docBin/
# Change Log for Composable Kernel
Full documentation for Composable Kernel is not yet available.
## CK 0.1.1 for ROCm 5.5.0
### Fixed
- Fixed a bug in 6-dimensional kernels (#555).
- Fixed grouped ConvBwdWeight test case failure (#524).
### Optimizations
- Improve proformance of normalization kernel
### Added
- Added user tutorial (#563).
- Added more instances for irregular GEMM sizes (#560).
- Added inter-wave consumer-producer programming model for GEMM kernels (#310).
- Added multi-D GEMM client APIs (#534).
- Added multi-embeddings support (#542).
- Added Navi3x blockwise GEMM and real GEMM support (#541).
### Changed
- Changed ...
...@@ -60,7 +60,7 @@ RUN dpkg -i dumb-init_*.deb && rm dumb-init_*.deb ...@@ -60,7 +60,7 @@ RUN dpkg -i dumb-init_*.deb && rm dumb-init_*.deb
ARG PREFIX=/opt/rocm ARG PREFIX=/opt/rocm
# Install packages for processing the performance results # Install packages for processing the performance results
RUN pip3 install --upgrade pip RUN pip3 install --upgrade pip
RUN pip3 install sqlalchemy RUN pip3 install sqlalchemy==1.4.46
RUN pip3 install pymysql RUN pip3 install pymysql
RUN pip3 install pandas RUN pip3 install pandas
RUN pip3 install setuptools-rust RUN pip3 install setuptools-rust
......
...@@ -19,7 +19,14 @@ def runShell(String command){ ...@@ -19,7 +19,14 @@ def runShell(String command){
} }
def getDockerImageName(){ def getDockerImageName(){
def img = "${env.CK_DOCKERHUB}:ck_ub20.04_rocm${params.ROCMVERSION}_${params.COMPILER_VERSION}" def img
if (params.COMPILER_COMMIT == ""){
img = "${env.CK_DOCKERHUB}:ck_ub20.04_rocm${params.ROCMVERSION}_${params.COMPILER_VERSION}"
}
else{
def commit = "${params.COMPILER_COMMIT}"[0..6]
img = "${env.CK_DOCKERHUB}:ck_ub20.04_rocm${params.ROCMVERSION}_${params.COMPILER_VERSION}_${commit}"
}
return img return img
} }
...@@ -550,8 +557,9 @@ def process_results(Map conf=[:]){ ...@@ -550,8 +557,9 @@ def process_results(Map conf=[:]){
} }
//launch develop branch daily at 23:00 UT in FULL_QA mode and at 19:00 UT with latest staging compiler version //launch develop branch daily at 23:00 UT in FULL_QA mode and at 19:00 UT with latest staging compiler version
CRON_SETTINGS = BRANCH_NAME == "develop" ? '''0 23 * * * % RUN_FULL_QA=true;COMPILER_VERSION=release CRON_SETTINGS = BRANCH_NAME == "develop" ? '''0 23 * * * % RUN_FULL_QA=true
0 19 * * * % BUILD_DOCKER=true;COMPILER_VERSION=amd-stg-open''' : "" 0 21 * * * % RUN_FULL_QA=false;COMPILER_VERSION=release;COMPILER_COMMIT=
0 19 * * * % BUILD_DOCKER=true;COMPILER_VERSION=amd-stg-open;COMPILER_COMMIT=''' : ""
pipeline { pipeline {
agent none agent none
...@@ -568,16 +576,16 @@ pipeline { ...@@ -568,16 +576,16 @@ pipeline {
description: "Force building docker image (default: false), set to true if docker image needs to be updated.") description: "Force building docker image (default: false), set to true if docker image needs to be updated.")
string( string(
name: 'ROCMVERSION', name: 'ROCMVERSION',
defaultValue: '5.3', defaultValue: '5.4.3',
description: 'Specify which ROCM version to use: 5.2.3, or 5.3 (default), etc.') description: 'Specify which ROCM version to use: 5.4.3 (default).')
string( string(
name: 'COMPILER_VERSION', name: 'COMPILER_VERSION',
defaultValue: 'release', defaultValue: 'amd-stg-open',
description: 'Specify which version of compiler to use: ck-9110, release (default), or amd-stg-open.') description: 'Specify which version of compiler to use: ck-9110, release, or amd-stg-open (default).')
string( string(
name: 'COMPILER_COMMIT', name: 'COMPILER_COMMIT',
defaultValue: '', defaultValue: '5541927df00eabd6a110180170eca7785d436ee3',
description: 'Specify which commit of compiler branch to use: leave empty to use the latest commit (default), or use 8a82e4eb7ba28521ba9a9424a0315a8a16590424 commit of amd-stg-open branch.') description: 'Specify which commit of compiler branch to use: leave empty to use the latest commit, or use 5541927df00eabd6a110180170eca7785d436ee3 (default) commit of amd-stg-open branch.')
string( string(
name: 'BUILD_COMPILER', name: 'BUILD_COMPILER',
defaultValue: 'hipcc', defaultValue: 'hipcc',
......
...@@ -83,7 +83,7 @@ int main(int argc, char* argv[]) ...@@ -83,7 +83,7 @@ int main(int argc, char* argv[])
[](std::size_t nRow, std::size_t nCol, std::size_t stride, auto layout) { [](std::size_t nRow, std::size_t nCol, std::size_t stride, auto layout) {
using Layout = decltype(layout); using Layout = decltype(layout);
if(std::is_same<Layout, ck::tensor_layout::gemm::RowMajor>::value) if constexpr(std::is_same<Layout, ck::tensor_layout::gemm::RowMajor>::value)
{ {
return (nRow - 1) * stride + nCol; return (nRow - 1) * stride + nCol;
} }
......
...@@ -92,7 +92,7 @@ int main(int argc, char* argv[]) ...@@ -92,7 +92,7 @@ int main(int argc, char* argv[])
[](std::size_t nRow, std::size_t nCol, std::size_t stride, auto layout) { [](std::size_t nRow, std::size_t nCol, std::size_t stride, auto layout) {
using Layout = decltype(layout); using Layout = decltype(layout);
if(std::is_same<Layout, ck::tensor_layout::gemm::RowMajor>::value) if constexpr(std::is_same<Layout, ck::tensor_layout::gemm::RowMajor>::value)
{ {
return (nRow - 1) * stride + nCol; return (nRow - 1) * stride + nCol;
} }
......
...@@ -88,7 +88,7 @@ int main(int argc, char* argv[]) ...@@ -88,7 +88,7 @@ int main(int argc, char* argv[])
[](std::size_t nRow, std::size_t nCol, std::size_t stride, auto layout) { [](std::size_t nRow, std::size_t nCol, std::size_t stride, auto layout) {
using Layout = decltype(layout); using Layout = decltype(layout);
if(std::is_same<Layout, ck::tensor_layout::gemm::RowMajor>::value) if constexpr(std::is_same<Layout, ck::tensor_layout::gemm::RowMajor>::value)
{ {
return (nRow - 1) * stride + nCol; return (nRow - 1) * stride + nCol;
} }
......
...@@ -84,7 +84,7 @@ int main(int argc, char* argv[]) ...@@ -84,7 +84,7 @@ int main(int argc, char* argv[])
[](std::size_t nRow, std::size_t nCol, std::size_t stride, auto layout) { [](std::size_t nRow, std::size_t nCol, std::size_t stride, auto layout) {
using Layout = decltype(layout); using Layout = decltype(layout);
if(std::is_same<Layout, ck::tensor_layout::gemm::RowMajor>::value) if constexpr(std::is_same<Layout, ck::tensor_layout::gemm::RowMajor>::value)
{ {
return (nRow - 1) * stride + nCol; return (nRow - 1) * stride + nCol;
} }
......
add_executable(client_gemm_add_add_reduce_normalize gemm_add_add_layernorm.cpp) add_executable(client_gemm_add_add_layernorm_naive gemm_add_add_layernorm_naive.cpp)
target_link_libraries(client_gemm_add_add_reduce_normalize PRIVATE composable_kernel::device_operations) target_link_libraries(client_gemm_add_add_layernorm_naive PRIVATE composable_kernel::device_operations)
add_executable(client_gemm_add_relu_add_layernorm_welford gemm_add_relu_add_layernorm_welford.cpp)
target_link_libraries(client_gemm_add_relu_add_layernorm_welford PRIVATE composable_kernel::device_operations)
...@@ -8,7 +8,7 @@ ...@@ -8,7 +8,7 @@
#include "ck/ck.hpp" #include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_reduce.hpp" #include "ck/tensor_operation/gpu/device/device_gemm_reduce.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_elementwise.hpp" #include "ck/tensor_operation/gpu/device/impl/device_elementwise_impl.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/gpu/device_elementwise_instance.hpp" #include "ck/library/tensor_operation_instance/gpu/device_elementwise_instance.hpp"
...@@ -190,7 +190,7 @@ int main() ...@@ -190,7 +190,7 @@ int main()
[](std::size_t nRow, std::size_t nCol, std::size_t stride, auto layout) { [](std::size_t nRow, std::size_t nCol, std::size_t stride, auto layout) {
using Layout = decltype(layout); using Layout = decltype(layout);
if(std::is_same<Layout, ck::tensor_layout::gemm::RowMajor>::value) if constexpr(std::is_same<Layout, ck::tensor_layout::gemm::RowMajor>::value)
{ {
return (nRow - 1) * stride + nCol; return (nRow - 1) * stride + nCol;
} }
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <iomanip>
#include <iostream>
#include <vector>
#include "ck/ck.hpp"
#include "ck/library/tensor_operation_instance/gpu/gemm_add_relu_add_layernorm.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d_layernorm.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
using F16 = ck::half_t;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using AddReluAdd = ck::tensor_operation::element_wise::AddReluAdd;
// DataType
using ADataType = F16;
using BDataType = F16;
using D0DataType = F16;
using D1DataType = F16;
using GammaDataType = F16;
using BetaDataType = F16;
using HDataType = F16;
// Layout
using ALayout = Row;
using BLayout = Col;
using D0Layout = Row;
using D1Layout = Row;
using HLayout = Row;
using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CDEElementOp = AddReluAdd;
using HElementOp = PassThrough;
struct SimpleDeviceMem
{
SimpleDeviceMem() = delete;
SimpleDeviceMem(std::size_t mem_size) : p_mem_{}, mMemSize_(mem_size)
{
(void)hipMalloc(static_cast<void**>(&p_mem_), mem_size);
}
void* GetDeviceBuffer() { return p_mem_; }
void SetZero() const { (void)hipMemset(p_mem_, 0, mMemSize_); }
~SimpleDeviceMem() { (void)hipFree(p_mem_); }
void* p_mem_;
std::size_t mMemSize_;
};
int main(int argc, char* argv[])
{
// GEMM shape
ck::index_t M = 1024;
ck::index_t N = 1024;
ck::index_t K = 1024;
ck::index_t StrideA = K;
ck::index_t StrideB = K;
ck::index_t StrideD0 = 0;
ck::index_t StrideD1 = N;
ck::index_t StrideH = N;
float epsilon = 1e-5;
auto f_matrix_space_size =
[](std::size_t nRow, std::size_t nCol, std::size_t stride, auto layout) {
using Layout = decltype(layout);
if constexpr(std::is_same<Layout, ck::tensor_layout::gemm::RowMajor>::value)
{
return (nRow - 1) * stride + nCol;
}
else
{
return (nCol - 1) * stride + nRow;
}
};
SimpleDeviceMem a_device_buf(sizeof(ADataType) * f_matrix_space_size(M, K, StrideA, ALayout{}));
SimpleDeviceMem b_device_buf(sizeof(BDataType) * f_matrix_space_size(K, N, StrideB, BLayout{}));
SimpleDeviceMem d0_device_buf(sizeof(D0DataType) *
f_matrix_space_size(M, N, StrideD0, D0Layout{}));
SimpleDeviceMem d1_device_buf(sizeof(D1DataType) *
f_matrix_space_size(M, N, StrideD1, D1Layout{}));
SimpleDeviceMem gamma_device_buf(sizeof(GammaDataType) * N);
SimpleDeviceMem beta_device_buf(sizeof(BetaDataType) * N);
SimpleDeviceMem h_device_buf(sizeof(HDataType) * f_matrix_space_size(M, N, StrideH, HLayout{}));
using DeviceOp = ck::tensor_operation::device::DeviceGemmMultipleDLayernorm<
ALayout,
BLayout,
ck::Tuple<D0Layout, D1Layout>,
HLayout,
ADataType,
BDataType,
ck::Tuple<D0DataType, D1DataType>,
GammaDataType,
BetaDataType,
HDataType,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::AddReluAdd,
ck::tensor_operation::element_wise::PassThrough>;
// get device op instances
const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
DeviceOp>::GetInstances();
std::cout << "found " << op_ptrs.size() << " instances" << std::endl;
const auto a_element_op = AElementOp{};
const auto b_element_op = BElementOp{};
const auto cde_element_op = CDEElementOp{};
const auto h_element_op = HElementOp{};
std::string best_op_name;
bool found = false;
int best_op_id = -1;
float best_ave_time = std::numeric_limits<float>::max();
float best_gb_per_sec = 0;
// profile device operation instances
std::cout << "Run all instances and do timing" << std::endl;
for(int i = 0; i < op_ptrs.size(); ++i)
{
auto& op_ptr = op_ptrs[i];
auto argument_ptr = op_ptr->MakeArgumentPointer(
a_device_buf.GetDeviceBuffer(),
b_device_buf.GetDeviceBuffer(),
{d0_device_buf.GetDeviceBuffer(), d1_device_buf.GetDeviceBuffer()},
gamma_device_buf.GetDeviceBuffer(),
beta_device_buf.GetDeviceBuffer(),
h_device_buf.GetDeviceBuffer(),
M,
N,
K,
StrideA,
StrideB,
{StrideD0, StrideD1},
StrideH,
epsilon,
a_element_op,
b_element_op,
cde_element_op,
h_element_op);
auto invoker_ptr = op_ptr->MakeInvokerPointer();
std::string op_name = op_ptr->GetTypeString();
if(op_ptr->IsSupportedArgument(argument_ptr.get()))
{
size_t workspace_sz = op_ptr->GetWorkSpaceSize(argument_ptr.get());
SimpleDeviceMem workspace_dev(workspace_sz);
op_ptr->SetWorkSpacePointer(argument_ptr.get(), workspace_dev.GetDeviceBuffer());
h_device_buf.SetZero();
float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true});
std::size_t num_byte =
sizeof(ADataType) * M * K + sizeof(BDataType) * K * N +
(sizeof(D0DataType) + sizeof(D1DataType) + sizeof(HDataType)) * M * N +
(sizeof(GammaDataType) + sizeof(BetaDataType)) * N;
float gb_per_sec = num_byte / 1.E6 / ave_time;
std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << gb_per_sec << " GB/s, "
<< op_name << std::endl;
if(ave_time < best_ave_time)
{
found = true;
best_op_id = i;
best_op_name = op_name;
best_ave_time = ave_time;
best_gb_per_sec = gb_per_sec;
}
}
else
{
std::cout << op_name << " does not support this problem" << std::endl;
}
}
std::cout << "Best Perf: " << best_ave_time << " ms, " << best_gb_per_sec << " GB/s, "
<< best_op_name << std::endl;
// run the best intance
{
auto& op_ptr = op_ptrs[best_op_id];
std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString()
<< std::endl;
auto argument_ptr = op_ptr->MakeArgumentPointer(
a_device_buf.GetDeviceBuffer(),
b_device_buf.GetDeviceBuffer(),
{d0_device_buf.GetDeviceBuffer(), d1_device_buf.GetDeviceBuffer()},
gamma_device_buf.GetDeviceBuffer(),
beta_device_buf.GetDeviceBuffer(),
h_device_buf.GetDeviceBuffer(),
M,
N,
K,
StrideA,
StrideB,
{StrideD0, StrideD1},
StrideH,
epsilon,
a_element_op,
b_element_op,
cde_element_op,
h_element_op);
auto invoker_ptr = op_ptr->MakeInvokerPointer();
if(op_ptr->IsSupportedArgument(argument_ptr.get()))
{
size_t workspace_sz = op_ptr->GetWorkSpaceSize(argument_ptr.get());
SimpleDeviceMem workspace_dev(workspace_sz);
op_ptr->SetWorkSpacePointer(argument_ptr.get(), workspace_dev.GetDeviceBuffer());
h_device_buf.SetZero();
invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false});
}
std::cout << "Done" << std::endl;
}
return 0;
}
\ No newline at end of file
...@@ -4,3 +4,6 @@ target_link_libraries(client_contraction_scale PRIVATE composable_kernel::device ...@@ -4,3 +4,6 @@ target_link_libraries(client_contraction_scale PRIVATE composable_kernel::device
add_executable(client_contraction_bilinear contraction_bilinear.cpp) add_executable(client_contraction_bilinear contraction_bilinear.cpp)
target_link_libraries(client_contraction_bilinear PRIVATE composable_kernel::device_operations) target_link_libraries(client_contraction_bilinear PRIVATE composable_kernel::device_operations)
add_executable(contraction_g1m2n3k1_add_xdl_fp16 contraction_g1m2n3k1_add_xdl_fp16.cpp)
target_link_libraries(contraction_g1m2n3k1_add_xdl_fp16 PRIVATE composable_kernel::device_operations)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <iomanip>
#include <numeric>
#include <vector>
#include <iostream>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_contraction_multiple_d.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/gpu/batched_gemm_bias_permute.hpp"
#include "ck/library/utility/numeric.hpp"
using F16 = ck::half_t;
using F32 = float;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using Add = ck::tensor_operation::element_wise::Add;
using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CDEElementOp = Add;
using ADataType = F16;
using BDataType = F16;
using AccDataType = F32;
using CShuffleDataType = F16;
using DDataType = F16;
using DsDataType = ck::Tuple<DDataType>;
using EDataType = F16;
static constexpr ck::index_t NumDimG = 1;
static constexpr ck::index_t NumDimM = 2;
static constexpr ck::index_t NumDimN = 3;
static constexpr ck::index_t NumDimK = 1;
struct SimpleDeviceMem
{
SimpleDeviceMem() = delete;
SimpleDeviceMem(std::size_t mem_size) : p_mem_{}
{
(void)hipMalloc(static_cast<void**>(&p_mem_), mem_size);
}
void* GetDeviceBuffer() { return p_mem_; }
~SimpleDeviceMem() { (void)hipFree(p_mem_); }
void* p_mem_;
};
int main(int argc, char* argv[])
{
ck::index_t G0 = 1;
ck::index_t M0 = 64;
ck::index_t M1 = 256;
ck::index_t N0 = 3;
ck::index_t N1 = 12;
ck::index_t N2 = 64;
ck::index_t K0 = 768;
// A[M0, M1, M2, K0]
std::vector<ck::index_t> a_gs_ms_ks_lengths{G0, M0, M1, K0};
std::vector<ck::index_t> a_gs_ms_ks_strides{M0 * M1 * K0, M1 * K0, K0, 1};
// B[N0, N1, N2, K0]
std::vector<ck::index_t> b_gs_ns_ks_lengths{G0, N0, N1, N2, K0};
std::vector<ck::index_t> b_gs_ns_ks_strides{N0 * N1 * N2 * K0, N1 * N2 * K0, N2 * K0, K0, 1};
// D[N0, M0, N1, M1, N2]
std::vector<ck::index_t> d_gs_ms_ns_lengths{G0, M0, M1, N0, N1, N2};
std::vector<ck::index_t> d_gs_ms_ns_strides{N0 * N1 * N2, 0, 0, N1 * N2, N2, 1};
// E[N0 M0 N1 N2 M1]
std::vector<ck::index_t> e_gs_ms_ns_lengths{G0, M0, M1, N0, N1, N2};
std::vector<ck::index_t> e_gs_ms_ns_strides{
M0 * M1 * N0 * N1 * N2, N1 * N2 * M1, 1, M0 * N1 * N2 * M1, M1 * N2, M1};
auto f_tensor_space_size = [](auto lengths, auto strides) {
std::size_t space_size = 1;
for(std::size_t i = 0; i < lengths.size(); ++i)
{
space_size += (lengths[i] - 1) * strides[i];
}
return space_size;
};
SimpleDeviceMem a_device_buf(sizeof(ADataType) *
f_tensor_space_size(a_gs_ms_ks_lengths, a_gs_ms_ks_strides));
SimpleDeviceMem b_device_buf(sizeof(BDataType) *
f_tensor_space_size(b_gs_ns_ks_lengths, b_gs_ns_ks_strides));
SimpleDeviceMem d_device_buf(sizeof(DDataType) *
f_tensor_space_size(d_gs_ms_ns_lengths, d_gs_ms_ns_strides));
SimpleDeviceMem e_device_buf(sizeof(EDataType) *
f_tensor_space_size(e_gs_ms_ns_lengths, e_gs_ms_ns_strides));
using DeviceOp = ck::tensor_operation::device::DeviceBatchedContractionMultipleD<
NumDimG,
NumDimM,
NumDimN,
NumDimK,
ADataType,
BDataType,
DsDataType,
EDataType,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::Add>;
// get device op instances
const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
DeviceOp>::GetInstances();
std::cout << "found " << op_ptrs.size() << " instances" << std::endl;
const auto a_element_op = AElementOp{};
const auto b_element_op = BElementOp{};
const auto cde_element_op = CDEElementOp{};
std::string best_op_name;
bool found = false;
int best_op_id = -1;
float best_ave_time = 0;
float best_tflops = 0;
float best_gb_per_sec = 0;
// profile device operation instances
std::cout << "Run all instances and do timing" << std::endl;
for(int i = 0; i < op_ptrs.size(); ++i)
{
auto& op_ptr = op_ptrs[i];
auto argument_ptr =
op_ptr->MakeArgumentPointer(a_device_buf.GetDeviceBuffer(),
b_device_buf.GetDeviceBuffer(),
std::array<const void*, 1>{d_device_buf.GetDeviceBuffer()},
e_device_buf.GetDeviceBuffer(),
a_gs_ms_ks_lengths,
a_gs_ms_ks_strides,
b_gs_ns_ks_lengths,
b_gs_ns_ks_strides,
std::array<std::vector<ck::index_t>, 1>{d_gs_ms_ns_lengths},
std::array<std::vector<ck::index_t>, 1>{d_gs_ms_ns_strides},
e_gs_ms_ns_lengths,
e_gs_ms_ns_strides,
a_element_op,
b_element_op,
cde_element_op);
auto invoker_ptr = op_ptr->MakeInvokerPointer();
std::string op_name = op_ptr->GetTypeString();
if(op_ptr->IsSupportedArgument(argument_ptr.get()))
{
float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true});
ck::index_t M = ck::accumulate_n<ck::index_t>(
e_gs_ms_ns_lengths.begin() + NumDimG, NumDimM, 1, std::multiplies<>{});
ck::index_t N = ck::accumulate_n<ck::index_t>(
e_gs_ms_ns_lengths.begin() + NumDimG + NumDimM, NumDimN, 1, std::multiplies<>{});
ck::index_t K = ck::accumulate_n<ck::index_t>(
a_gs_ms_ks_lengths.begin() + NumDimG + NumDimM, NumDimK, 1, std::multiplies<>{});
std::size_t flop = std::size_t(2) * M * N * K;
std::size_t num_btype = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N +
sizeof(DDataType) * M * N + sizeof(EDataType) * M * N;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, "
<< gb_per_sec << " GB/s, " << op_name << std::endl;
if(tflops > best_tflops)
{
found = true;
best_op_id = i;
best_op_name = op_name;
best_tflops = tflops;
best_ave_time = ave_time;
best_gb_per_sec = gb_per_sec;
}
}
else
{
std::cout << op_name << " does not support this problem" << std::endl;
}
}
std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, "
<< best_gb_per_sec << " GB/s, " << best_op_name << std::endl;
return 0;
}
...@@ -16,7 +16,7 @@ using XDataType = ck::half_t; ...@@ -16,7 +16,7 @@ using XDataType = ck::half_t;
using GammaDataType = ck::half_t; using GammaDataType = ck::half_t;
using BetaDataType = ck::half_t; using BetaDataType = ck::half_t;
using YDataType = ck::half_t; using YDataType = ck::half_t;
using AccDataType = float; using ComputeDataType = float;
using PassThrough = ck::tensor_operation::element_wise::PassThrough; using PassThrough = ck::tensor_operation::element_wise::PassThrough;
constexpr int Rank = 2; constexpr int Rank = 2;
...@@ -54,7 +54,7 @@ int main(int argc, char* argv[]) ...@@ -54,7 +54,7 @@ int main(int argc, char* argv[])
using DeviceOp = ck::tensor_operation::device::DeviceNormalization<XDataType, using DeviceOp = ck::tensor_operation::device::DeviceNormalization<XDataType,
GammaDataType, GammaDataType,
BetaDataType, BetaDataType,
AccDataType, ComputeDataType,
YDataType, YDataType,
PassThrough, PassThrough,
Rank, Rank,
......
...@@ -47,8 +47,8 @@ int main(int argc, char* argv[]) ...@@ -47,8 +47,8 @@ int main(int argc, char* argv[])
ck::index_t num_elements = ck::index_t num_elements =
std::accumulate(in_lengths.begin(), in_lengths.end(), 1, std::multiplies<ck::index_t>()); std::accumulate(in_lengths.begin(), in_lengths.end(), 1, std::multiplies<ck::index_t>());
AccDataType alpha{2.0f}; double alpha{2.0};
AccDataType beta{2.0f}; double beta{2.0};
SimpleDeviceMem in(sizeof(InDataType) * num_elements); SimpleDeviceMem in(sizeof(InDataType) * num_elements);
SimpleDeviceMem out(sizeof(OutDataType) * num_elements); SimpleDeviceMem out(sizeof(OutDataType) * num_elements);
...@@ -82,8 +82,8 @@ int main(int argc, char* argv[]) ...@@ -82,8 +82,8 @@ int main(int argc, char* argv[])
auto argument_ptr = op_ptr->MakeArgumentPointer(in_lengths, auto argument_ptr = op_ptr->MakeArgumentPointer(in_lengths,
in_strides, in_strides,
reduce_dims, reduce_dims,
&alpha, alpha,
&beta, beta,
in.GetDeviceBuffer(), in.GetDeviceBuffer(),
out.GetDeviceBuffer(), out.GetDeviceBuffer(),
PassThrough{}, PassThrough{},
...@@ -129,8 +129,8 @@ int main(int argc, char* argv[]) ...@@ -129,8 +129,8 @@ int main(int argc, char* argv[])
auto argument_ptr = op_ptr->MakeArgumentPointer(in_lengths, auto argument_ptr = op_ptr->MakeArgumentPointer(in_lengths,
in_strides, in_strides,
reduce_dims, reduce_dims,
&alpha, alpha,
&beta, beta,
in.GetDeviceBuffer(), in.GetDeviceBuffer(),
out.GetDeviceBuffer(), out.GetDeviceBuffer(),
PassThrough{}, PassThrough{},
......
add_executable(client_fused_attention fused_attention.cpp) add_executable(client_fused_attention fused_attention.cpp)
target_link_libraries(client_fused_attention PRIVATE composable_kernel::device_operations) target_link_libraries(client_fused_attention PRIVATE composable_kernel::device_operations)
add_executable(client_fused_attention_bias fused_attention_bias.cpp)
target_link_libraries(client_fused_attention_bias PRIVATE composable_kernel::device_operations)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <vector>
#include "ck/ck.hpp"
#include "ck/library/tensor_operation_instance/gpu/batched_gemm_bias_softmax_gemm_permute.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_permute.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
using AElementOp = ck::tensor_operation::element_wise::PassThrough;
using B0ElementOp = ck::tensor_operation::element_wise::PassThrough;
using Acc0ElementOp = ck::tensor_operation::element_wise::ScaleAdd;
using B1ElementOp = ck::tensor_operation::element_wise::PassThrough;
using CElementOp = ck::tensor_operation::element_wise::PassThrough;
constexpr static auto MaskingSpec =
ck::tensor_operation::device::MaskingSpecialization::MaskDisabled;
using ADataType = ck::half_t;
using B0DataType = ck::half_t;
using B1DataType = ck::half_t;
using CDataType = ck::half_t;
using D0DataType = ck::half_t;
using AccDataType = float;
struct SimpleDeviceMem
{
SimpleDeviceMem() = delete;
SimpleDeviceMem(std::size_t mem_size) : p_mem_{}
{
(void)hipMalloc(static_cast<void**>(&p_mem_), mem_size);
}
void* GetDeviceBuffer() { return p_mem_; }
~SimpleDeviceMem() { (void)hipFree(p_mem_); }
void* p_mem_;
};
int main(int argc, char* argv[])
{
int G0 = 48;
int G1 = 16;
int M = 1024;
int N = 1024;
int K = 64;
int O = 64;
// A layout [G0, M, G1, K]
std::vector<ck::index_t> a_gs_ms_ks_lengths{G0, G1, M, K};
std::vector<ck::index_t> a_gs_ms_ks_strides{M * G1 * K, K, G1 * K, 1};
// B0 layout [G0, N, G1, K]
std::vector<ck::index_t> b0_gs_ns_ks_lengths{G0, G1, N, K};
std::vector<ck::index_t> b0_gs_ns_ks_strides{N * G1 * K, K, G1 * K, 1};
// B1 layout [G0, N, G1, O]
std::vector<ck::index_t> b1_gs_os_ns_lengths{G0, G1, O, N};
std::vector<ck::index_t> b1_gs_os_ns_strides{N * G1 * O, O, 1, G1 * O};
// C layout [G0, M, G1, O]
std::vector<ck::index_t> c_gs_ms_os_lengths{G0, G1, M, O};
std::vector<ck::index_t> c_gs_ms_os_strides{M * G1 * O, O, G1 * O, 1};
// D layout [G0, M, G1, N]
std::vector<ck::index_t> d0_gs_ms_ns_lengths{G0, G1, M, N};
std::vector<ck::index_t> d0_gs_ms_ns_strides{M * G1 * N, N, G1 * N, 1};
SimpleDeviceMem a_device_buf(sizeof(ADataType) * G0 * G1 * M * K);
SimpleDeviceMem b0_device_buf(sizeof(B0DataType) * G0 * G1 * N * K);
SimpleDeviceMem d0_device_buf(sizeof(D0DataType) * G0 * G1 * M * N);
SimpleDeviceMem b1_device_buf(sizeof(B1DataType) * G0 * G1 * O * N);
SimpleDeviceMem c_device_buf(sizeof(CDataType) * G0 * G1 * M * O);
using DeviceOp =
ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute<2,
1,
1,
1,
1,
ADataType,
B0DataType,
B1DataType,
CDataType,
ck::Tuple<D0DataType>,
ck::Tuple<>,
AElementOp,
B0ElementOp,
Acc0ElementOp,
B1ElementOp,
CElementOp,
MaskingSpec>;
// get device op instances
const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
DeviceOp>::GetInstances();
std::cout << "found " << op_ptrs.size() << " instances" << std::endl;
std::string best_op_name;
int best_op_id = -1;
float best_ave_time = 0;
float best_tflops = 0;
float best_gb_per_sec = 0;
// profile device op instances
std::cout << "Run all instances and do timing" << std::endl;
for(int i = 0; i < op_ptrs.size(); ++i)
{
auto& op_ptr = op_ptrs[i];
auto argument_ptr = op_ptr->MakeArgumentPointer(
a_device_buf.GetDeviceBuffer(),
b0_device_buf.GetDeviceBuffer(),
b1_device_buf.GetDeviceBuffer(),
c_device_buf.GetDeviceBuffer(),
std::array<void*, 1>{d0_device_buf.GetDeviceBuffer()}, // p_acc0_biases
{}, // p_acc1_biases
a_gs_ms_ks_lengths,
a_gs_ms_ks_strides,
b0_gs_ns_ks_lengths,
b0_gs_ns_ks_strides,
b1_gs_os_ns_lengths,
b1_gs_os_ns_strides,
c_gs_ms_os_lengths,
c_gs_ms_os_strides,
std::array<std::vector<ck::index_t>, 1>{
d0_gs_ms_ns_lengths}, // acc0_biases_gs_ms_ns_lengths
std::array<std::vector<ck::index_t>, 1>{
d0_gs_ms_ns_strides}, // acc0_biases_gs_ms_ns_strides
{}, // acc1_biases_gs_ms_os_lengths
{}, // acc1_biases_gs_ms_os_strides
AElementOp{},
B0ElementOp{},
Acc0ElementOp{1 / sqrtf(K)},
B1ElementOp{},
CElementOp{});
auto invoker_ptr = op_ptr->MakeInvokerPointer();
std::string op_name = op_ptr->GetTypeString();
if(op_ptr->IsSupportedArgument(argument_ptr.get()))
{
float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true});
std::size_t flop = (size_t(M) * N * K * 2 + size_t(M) * N * O * 2) * G0 * G1;
std::size_t num_btype = (sizeof(ADataType) * M * K + sizeof(B0DataType) * K * N +
sizeof(B1DataType) * N * O + sizeof(CDataType) * M * O +
sizeof(D0DataType) * M * N) *
G0 * G1;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec
<< " GB/s, " << op_name << std::endl;
if(tflops > best_tflops)
{
best_op_id = i;
best_op_name = op_name;
best_tflops = tflops;
best_ave_time = ave_time;
best_gb_per_sec = gb_per_sec;
}
}
else
{
std::cout << op_name << " does not support this problem" << std::endl;
}
}
std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, "
<< best_gb_per_sec << " GB/s, " << best_op_name << std::endl;
// run the best instance
{
auto& op_ptr = op_ptrs[best_op_id];
std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString()
<< std::endl;
auto argument_ptr = op_ptr->MakeArgumentPointer(
a_device_buf.GetDeviceBuffer(),
b0_device_buf.GetDeviceBuffer(),
b1_device_buf.GetDeviceBuffer(),
c_device_buf.GetDeviceBuffer(),
std::array<void*, 1>{d0_device_buf.GetDeviceBuffer()}, // p_acc0_biases
{}, // p_acc1_biases
a_gs_ms_ks_lengths,
a_gs_ms_ks_strides,
b0_gs_ns_ks_lengths,
b0_gs_ns_ks_strides,
b1_gs_os_ns_lengths,
b1_gs_os_ns_strides,
c_gs_ms_os_lengths,
c_gs_ms_os_strides,
std::array<std::vector<ck::index_t>, 1>{
d0_gs_ms_ns_lengths}, // acc0_biases_gs_ms_ns_lengths
std::array<std::vector<ck::index_t>, 1>{
d0_gs_ms_ns_strides}, // acc0_biases_gs_ms_ns_strides
{}, // acc1_biases_gs_ms_os_lengths
{}, // acc1_biases_gs_ms_os_strides
AElementOp{},
B0ElementOp{},
Acc0ElementOp{1 / sqrtf(K)},
B1ElementOp{},
CElementOp{});
auto invoker_ptr = op_ptr->MakeInvokerPointer();
if(op_ptr->IsSupportedArgument(argument_ptr.get()))
{
invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false});
}
std::cout << "Done" << std::endl;
}
return 0;
}
add_executable(client_grouped_conv2d_bwd_weight grouped_conv2d_bwd_weight.cpp) add_executable(client_grouped_conv2d_bwd_weight_fp16 grouped_conv2d_bwd_weight_fp16.cpp)
target_link_libraries(client_grouped_conv2d_bwd_weight PRIVATE composable_kernel::device_operations) add_executable(client_grouped_conv3d_bwd_weight_fp16 grouped_conv3d_bwd_weight_fp16.cpp)
add_executable(client_grouped_conv3d_bwd_weight_fp32 grouped_conv3d_bwd_weight_fp32.cpp)
target_link_libraries(client_grouped_conv2d_bwd_weight_fp16 PRIVATE composable_kernel::device_operations)
target_link_libraries(client_grouped_conv3d_bwd_weight_fp16 PRIVATE composable_kernel::device_operations)
target_link_libraries(client_grouped_conv3d_bwd_weight_fp32 PRIVATE composable_kernel::device_operations)
...@@ -13,27 +13,8 @@ ...@@ -13,27 +13,8 @@
#include "ck/tensor_operation/gpu/device/device_conv_fwd.hpp" #include "ck/tensor_operation/gpu/device/device_conv_fwd.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
using InDataType = ck::half_t;
using WeiDataType = ck::half_t;
using OutDataType = ck::half_t;
using InLayout = ck::tensor_layout::convolution::GNHWC;
using WeiLayout = ck::tensor_layout::convolution::GKYXC;
using OutLayout = ck::tensor_layout::convolution::GNHWK;
using PassThrough = ck::tensor_operation::element_wise::PassThrough; using PassThrough = ck::tensor_operation::element_wise::PassThrough;
static constexpr ck::index_t NumDimSpatial = 2;
static constexpr ck::index_t G = 32;
static constexpr ck::index_t N = 256;
static constexpr ck::index_t K = 192;
static constexpr ck::index_t C = 192;
static constexpr ck::index_t Y = 3;
static constexpr ck::index_t X = 3;
static constexpr ck::index_t Hi = 28;
static constexpr ck::index_t Wi = 28;
static constexpr ck::index_t Ho = 28;
static constexpr ck::index_t Wo = 28;
struct SimpleDeviceMem struct SimpleDeviceMem
{ {
SimpleDeviceMem() = delete; SimpleDeviceMem() = delete;
...@@ -50,22 +31,93 @@ struct SimpleDeviceMem ...@@ -50,22 +31,93 @@ struct SimpleDeviceMem
void* p_mem_; void* p_mem_;
}; };
int main() template <ck::index_t NumDimSpatial>
std::size_t GetFlops(ck::index_t G,
ck::index_t N,
ck::index_t K,
ck::index_t C,
const std::array<ck::index_t, NumDimSpatial>& output_spatial_lengths,
const std::array<ck::index_t, NumDimSpatial>& filter_spatial_lengths)
{ {
std::array<ck::index_t, NumDimSpatial> input_spatial_lengths{Hi, Wi}; // 2 * G * N * K * C * <output spatial lengths product> * <filter spatial lengths product>
std::array<ck::index_t, NumDimSpatial> filter_spatial_lengths{Y, X}; return static_cast<std::size_t>(2) * G * N * K * C *
std::array<ck::index_t, NumDimSpatial> output_spatial_lengths{Ho, Wo}; std::accumulate(std::begin(output_spatial_lengths),
std::end(output_spatial_lengths),
static_cast<std::size_t>(1),
std::multiplies<>()) *
std::accumulate(std::begin(filter_spatial_lengths),
std::end(filter_spatial_lengths),
static_cast<std::size_t>(1),
std::multiplies<>());
}
std::array<ck::index_t, NumDimSpatial> conv_filter_strides{1, 1}; template <typename InDataType, ck::index_t NumDimSpatial>
std::array<ck::index_t, NumDimSpatial> conv_filter_dilations{1, 1}; std::size_t GetInputByte(ck::index_t G,
std::array<ck::index_t, NumDimSpatial> input_left_pads{1, 1}; ck::index_t N,
std::array<ck::index_t, NumDimSpatial> input_right_pads{1, 1}; ck::index_t C,
const std::array<ck::index_t, NumDimSpatial>& input_spatial_lengths)
{
// sizeof(InDataType) * (G * N * C * <input spatial lengths product>) +
return sizeof(InDataType) * (G * N * C *
std::accumulate(std::begin(input_spatial_lengths),
std::end(input_spatial_lengths),
static_cast<std::size_t>(1),
std::multiplies<>()));
}
ck::index_t split_k = 2; template <typename WeiDataType, ck::index_t NumDimSpatial>
std::size_t GetWeightByte(ck::index_t G,
ck::index_t K,
ck::index_t C,
const std::array<ck::index_t, NumDimSpatial>& filter_spatial_lengths)
{
// sizeof(WeiDataType) * (G * K * C * <filter spatial lengths product>) +
return sizeof(WeiDataType) * (G * K * C *
std::accumulate(std::begin(filter_spatial_lengths),
std::end(filter_spatial_lengths),
static_cast<std::size_t>(1),
std::multiplies<>()));
}
SimpleDeviceMem in(sizeof(InDataType) * G * N * Hi * Wi * C); template <typename OutDataType, ck::index_t NumDimSpatial>
SimpleDeviceMem wei(sizeof(WeiDataType) * G * K * Y * X * C); std::size_t GetOutputByte(ck::index_t G,
SimpleDeviceMem out(sizeof(OutDataType) * G * N * Ho * Wo * K); ck::index_t N,
ck::index_t K,
const std::array<ck::index_t, NumDimSpatial>& output_spatial_lengths)
{
// sizeof(OutDataType) * (G * N * K * <output spatial lengths product>);
return sizeof(OutDataType) * (G * N * K *
std::accumulate(std::begin(output_spatial_lengths),
std::end(output_spatial_lengths),
static_cast<std::size_t>(1),
std::multiplies<std::size_t>()));
}
template <ck::index_t NumDimSpatial,
typename InDataType,
typename WeiDataType,
typename OutDataType,
typename InLayout,
typename WeiLayout,
typename OutLayout>
bool run_grouped_conv_bwd_weight(
ck::index_t G,
ck::index_t N,
ck::index_t K,
ck::index_t C,
const std::array<ck::index_t, NumDimSpatial>& input_spatial_lengths,
const std::array<ck::index_t, NumDimSpatial>& filter_spatial_lengths,
const std::array<ck::index_t, NumDimSpatial>& output_spatial_lengths,
const std::array<ck::index_t, NumDimSpatial>& conv_filter_strides,
const std::array<ck::index_t, NumDimSpatial>& conv_filter_dilations,
const std::array<ck::index_t, NumDimSpatial>& input_left_pads,
const std::array<ck::index_t, NumDimSpatial>& input_right_pads)
{
ck::index_t split_k = 2;
SimpleDeviceMem in(GetInputByte<InDataType, NumDimSpatial>(G, N, C, input_spatial_lengths));
SimpleDeviceMem wei(GetWeightByte<WeiDataType, NumDimSpatial>(G, K, C, filter_spatial_lengths));
SimpleDeviceMem out(GetOutputByte<OutDataType, NumDimSpatial>(G, N, K, output_spatial_lengths));
using DeviceOp = ck::tensor_operation::device::DeviceGroupedConvBwdWeight<NumDimSpatial, using DeviceOp = ck::tensor_operation::device::DeviceGroupedConvBwdWeight<NumDimSpatial,
InLayout, InLayout,
...@@ -120,10 +172,12 @@ int main() ...@@ -120,10 +172,12 @@ int main()
{ {
float avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true}); float avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true});
std::size_t flop = std::size_t(2) * G * N * K * C * Ho * Wo * Y * X; std::size_t flop =
std::size_t num_bytes = sizeof(InDataType) * G * N * Hi * Wi * C + GetFlops<NumDimSpatial>(G, N, K, C, output_spatial_lengths, filter_spatial_lengths);
sizeof(WeiDataType) * G * K * Y * X * C + std::size_t num_bytes =
sizeof(OutDataType) * G * N * Ho * Wo * K; GetInputByte<InDataType, NumDimSpatial>(G, N, C, input_spatial_lengths) +
GetWeightByte<WeiDataType, NumDimSpatial>(G, K, C, filter_spatial_lengths) +
GetOutputByte<OutDataType, NumDimSpatial>(G, N, K, output_spatial_lengths);
float tflops = static_cast<float>(flop) / 1.E9 / avg_time; float tflops = static_cast<float>(flop) / 1.E9 / avg_time;
float gb_per_sec = num_bytes / 1.E6 / avg_time; float gb_per_sec = num_bytes / 1.E6 / avg_time;
...@@ -149,7 +203,7 @@ int main() ...@@ -149,7 +203,7 @@ int main()
if(best_op_id < 0) if(best_op_id < 0)
{ {
std::cerr << "no suitable instance" << std::endl; std::cerr << "no suitable instance" << std::endl;
return EXIT_FAILURE; return false;
} }
std::cout << "Best Perf: " << std::setw(10) << best_avg_time << " ms, " << best_tflops std::cout << "Best Perf: " << std::setw(10) << best_avg_time << " ms, " << best_tflops
...@@ -187,4 +241,6 @@ int main() ...@@ -187,4 +241,6 @@ int main()
std::cout << "Done" << std::endl; std::cout << "Done" << std::endl;
} }
return true;
} }
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include "common.hpp"
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
using InDataType = ck::half_t;
using WeiDataType = ck::half_t;
using OutDataType = ck::half_t;
using InLayout = ck::tensor_layout::convolution::GNHWC;
using WeiLayout = ck::tensor_layout::convolution::GKYXC;
using OutLayout = ck::tensor_layout::convolution::GNHWK;
static constexpr ck::index_t NumDimSpatial = 2;
static constexpr ck::index_t G = 32;
static constexpr ck::index_t N = 256;
static constexpr ck::index_t K = 192;
static constexpr ck::index_t C = 192;
static constexpr ck::index_t Y = 3;
static constexpr ck::index_t X = 3;
static constexpr ck::index_t Hi = 28;
static constexpr ck::index_t Wi = 28;
static constexpr ck::index_t Ho = 28;
static constexpr ck::index_t Wo = 28;
int main()
{
return run_grouped_conv_bwd_weight<NumDimSpatial,
InDataType,
WeiDataType,
OutDataType,
InLayout,
WeiLayout,
OutLayout>(
G, N, K, C, {Hi, Wi}, {Y, X}, {Ho, Wo}, {1, 1}, {1, 1}, {1, 1}, {1, 1})
? EXIT_SUCCESS
: EXIT_FAILURE;
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment