Commit 1864dfe1 authored by Bartlomiej Kocot's avatar Bartlomiej Kocot
Browse files

Introduce profile_contraction_scale and profile_contraction_bilinear

parent 304adaad
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <vector>
#include "ck/ck.hpp"
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using Bilinear = ck::tensor_operation::element_wise::Bilinear;
using Scale = ck::tensor_operation::element_wise::Scale;
inline void collect_index_params(char* argv[],
std::vector<ck::index_t>& params,
const ck::index_t from,
const ck::index_t num)
{
for(ck::index_t p = from; p < from + num; p++)
params.push_back(std::stoi(argv[p]));
}
// Defualt strides for row-major: {Dim1 * Dim2 * Dim3, Dim2 * Dim3, Dim3, 1}
// Defualt strides for column-major: {Dim1, 1, Dim0 * Dim1 * Dim3, Dim0 * Dim1}
inline void
assign_default_strides(Row, std::vector<ck::index_t>& strides, std::vector<ck::index_t> dims)
{
strides = {dims[1] * dims[2] * dims[3], dims[2] * dims[3], dims[3], 1};
}
inline void
assign_default_strides(Col, std::vector<ck::index_t>& strides, std::vector<ck::index_t> dims)
{
strides = {dims[1], 1, dims[0] * dims[1] * dims[3], dims[0] * dims[1]};
}
...@@ -30,7 +30,8 @@ set(PROFILER_SOURCES ...@@ -30,7 +30,8 @@ set(PROFILER_SOURCES
profile_batchnorm_bwd.cpp profile_batchnorm_bwd.cpp
profile_batchnorm_infer.cpp profile_batchnorm_infer.cpp
profile_grouped_gemm_fastgelu.cpp profile_grouped_gemm_fastgelu.cpp
profile_contraction.cpp profile_contraction_bilinear.cpp
profile_contraction_scale.cpp
) )
set(PROFILER_EXECUTABLE ckProfiler) set(PROFILER_EXECUTABLE ckProfiler)
......
...@@ -8,6 +8,7 @@ ...@@ -8,6 +8,7 @@
#include <vector> #include <vector>
#include "profiler/profile_contraction_impl.hpp" #include "profiler/profile_contraction_impl.hpp"
#include "profiler/profile_contraction_utils.hpp"
#include "profiler_operation_registry.hpp" #include "profiler_operation_registry.hpp"
enum struct ContractionMatrixLayout enum struct ContractionMatrixLayout
...@@ -24,14 +25,8 @@ enum struct ContractionDataType ...@@ -24,14 +25,8 @@ enum struct ContractionDataType
F64_F64_F64_F64, // 1 F64_F64_F64_F64, // 1
}; };
#define OP_NAME "contraction" #define OP_NAME "contraction_bilinear"
#define OP_DESC "CONTRACTION" #define OP_DESC "CONTRACTION+Bilinear"
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using Bilinear = ck::tensor_operation::element_wise::Bilinear;
using Scale = ck::tensor_operation::element_wise::Scale;
static void print_helper_msg() static void print_helper_msg()
{ {
...@@ -50,45 +45,17 @@ static void print_helper_msg() ...@@ -50,45 +45,17 @@ static void print_helper_msg()
<< "value)\n" << "value)\n"
<< "arg6: print tensor value (0: no; 1: yes)\n" << "arg6: print tensor value (0: no; 1: yes)\n"
<< "arg7: time kernel (0: no, 1: yes)\n" << "arg7: time kernel (0: no, 1: yes)\n"
<< "arg8 and arg9(optional): alpha and beta for bilinear (pass only " << "arg8 and arg9: alpha and beta\n"
<< "alpha for scale)\n" << "arg10 to 15: M0, M1, N0, N1, K0, K1\n"
<< "arg9/10 to 14/15: M0, M1, N0, N1, K0, K1\n" << "arg16 to 31: Strides for A, B, C and D (skip for default)\n"
<< "arg15/16 to 30/31: Strides for A, B, C and D (skip for default)\n"
<< std::endl; << std::endl;
} }
void collect_index_params(char* argv[], int profile_contraction_bilinear(int argc, char* argv[])
std::vector<ck::index_t>& params,
const ck::index_t from,
const ck::index_t num)
{
for(ck::index_t p = from; p < from + num; p++)
params.push_back(std::stoi(argv[p]));
}
// Defualt strides for row-major: {Dim1 * Dim2 * Dim3, Dim2 * Dim3, Dim3, 1}
// Defualt strides for column-major: {Dim1, 1, Dim0 * Dim1 * Dim3, Dim0 * Dim1}
void assign_default_strides(Row, std::vector<ck::index_t>& strides, std::vector<ck::index_t> dims)
{
strides = {dims[1] * dims[2] * dims[3], dims[2] * dims[3], dims[3], 1};
}
void assign_default_strides(Col, std::vector<ck::index_t>& strides, std::vector<ck::index_t> dims)
{ {
strides = {dims[1], 1, dims[0] * dims[1] * dims[3], dims[0] * dims[1]}; const bool default_strides = argc == 16;
}
int profile_contraction(int argc, char* argv[]) if(argc != 32 && argc != 16)
{
const bool all_parameters_bilinear = argc == 32;
const bool all_parameters_scale = argc == 31;
const bool parameters_wo_strides_bilinear = argc == 16;
const bool parameters_wo_strides_scale = argc == 15;
const bool default_strides = parameters_wo_strides_bilinear || parameters_wo_strides_scale;
const bool with_bilinear = all_parameters_bilinear || parameters_wo_strides_bilinear;
if(!(all_parameters_bilinear || all_parameters_scale || parameters_wo_strides_bilinear ||
parameters_wo_strides_scale))
{ {
print_helper_msg(); print_helper_msg();
exit(1); exit(1);
...@@ -101,12 +68,12 @@ int profile_contraction(int argc, char* argv[]) ...@@ -101,12 +68,12 @@ int profile_contraction(int argc, char* argv[])
const bool do_log = std::stoi(argv[6]); const bool do_log = std::stoi(argv[6]);
const bool time_kernel = std::stoi(argv[7]); const bool time_kernel = std::stoi(argv[7]);
const float alpha = std::stof(argv[8]); const float alpha = std::stof(argv[8]);
const float beta = with_bilinear ? std::stof(argv[9]) : 0; const float beta = std::stof(argv[9]);
std::vector<ck::index_t> M; std::vector<ck::index_t> M;
std::vector<ck::index_t> N; std::vector<ck::index_t> N;
std::vector<ck::index_t> K; std::vector<ck::index_t> K;
const ck::index_t dims_arg_num = with_bilinear ? 10 : 9; const ck::index_t dims_arg_num = 10;
collect_index_params(argv, M, dims_arg_num, 2); collect_index_params(argv, M, dims_arg_num, 2);
collect_index_params(argv, N, dims_arg_num + 2, 2); collect_index_params(argv, N, dims_arg_num + 2, 2);
collect_index_params(argv, K, dims_arg_num + 4, 2); collect_index_params(argv, K, dims_arg_num + 4, 2);
...@@ -140,44 +107,23 @@ int profile_contraction(int argc, char* argv[]) ...@@ -140,44 +107,23 @@ int profile_contraction(int argc, char* argv[])
assign_default_strides(cd_layout, StridesC, {M[0], M[1], N[0], N[1]}); assign_default_strides(cd_layout, StridesC, {M[0], M[1], N[0], N[1]});
assign_default_strides(cd_layout, StridesD, {M[0], M[1], N[0], N[1]}); assign_default_strides(cd_layout, StridesD, {M[0], M[1], N[0], N[1]});
} }
bool pass; bool pass = ck::profiler::profile_contraction_impl<ALayout,
if(with_bilinear) BLayout,
{ CDLayout,
pass = ck::profiler::profile_contraction_impl<ALayout, DataType,
BLayout, ck::Tuple<DataType>,
CDLayout, Bilinear>(do_verification,
DataType, init_method,
ck::Tuple<DataType>, do_log,
Bilinear>(do_verification, time_kernel,
init_method, Bilinear{alpha, beta},
do_log, M,
time_kernel, N,
Bilinear{alpha, beta}, K,
M, StridesA,
N, StridesB,
K, StridesC,
StridesA, StridesD);
StridesB,
StridesC,
StridesD);
}
else
{
pass = ck::profiler::
profile_contraction_impl<ALayout, BLayout, CDLayout, DataType, ck::Tuple<>, Scale>(
do_verification,
init_method,
do_log,
time_kernel,
Scale{alpha},
M,
N,
K,
StridesA,
StridesB,
StridesC,
StridesD);
}
return pass; return pass;
}; };
...@@ -230,4 +176,4 @@ int profile_contraction(int argc, char* argv[]) ...@@ -230,4 +176,4 @@ int profile_contraction(int argc, char* argv[])
} }
} }
REGISTER_PROFILER_OPERATION(OP_NAME, OP_DESC, profile_contraction); REGISTER_PROFILER_OPERATION(OP_NAME, OP_DESC, profile_contraction_bilinear);
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include <vector>
#include "profiler/profile_contraction_impl.hpp"
#include "profiler/profile_contraction_utils.hpp"
#include "profiler_operation_registry.hpp"
enum struct ContractionMatrixLayout
{
MK_KN_MN_MN, // 0
MK_NK_MN_MN, // 1
KM_KN_MN_MN, // 2
KM_NK_MN_MN, // 3
};
enum struct ContractionDataType
{
F32_F32_F32_F32, // 0
F64_F64_F64_F64, // 1
};
#define OP_NAME "contraction_scale"
#define OP_DESC "CONTRACTION+Scale"
static void print_helper_msg()
{
std::cout << "arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n"
<< "arg2: data type (0: fp32; 1: f64)\n"
<< "arg3: matrix layout (0: A[m0, m1, k0, k1] * B[k0, k1, n0, n1] + "
"D[m0, m1, n0, n1] = C[m0, m1, n0, n1];\n"
<< " 1: A[m0, m1, k0, k1] * B[n0, n1, k0, k1] + "
"D[m0, m1, n0, n1] = C[m0, m1, n0, n1];\n"
<< " 2: A[k0, k1, m0, m1] * B[k0, k1, n0, n1] + "
"D[m0, m1, n0, n1] = C[m0, m1, n0, n1];\n"
<< " 3: A[k0, k1, m0, m1] * B[n0, n1, k0, k1] + "
"D[m0, m1, n0, n1] = C[m0, m1, n0, n1])\n"
<< "arg4: verification (0: no; 1: yes)\n"
<< "arg5: initialization (0: no init; 1: integer value; 2: decimal "
<< "value)\n"
<< "arg6: print tensor value (0: no; 1: yes)\n"
<< "arg7: time kernel (0: no, 1: yes)\n"
<< "arg8: alpha\n"
<< "arg9 to 14: M0, M1, N0, N1, K0, K1\n"
<< "arg15 to 30: Strides for A, B, C and D (skip for default)\n"
<< std::endl;
}
int profile_contraction_scale(int argc, char* argv[])
{
const bool default_strides = argc == 15;
if(argc != 31 && argc != 15)
{
print_helper_msg();
exit(1);
}
const auto data_type = static_cast<ContractionDataType>(std::stoi(argv[2]));
const auto layout = static_cast<ContractionMatrixLayout>(std::stoi(argv[3]));
const bool do_verification = std::stoi(argv[4]);
const ck::index_t init_method = std::stoi(argv[5]);
const bool do_log = std::stoi(argv[6]);
const bool time_kernel = std::stoi(argv[7]);
const float alpha = std::stof(argv[8]);
std::vector<ck::index_t> M;
std::vector<ck::index_t> N;
std::vector<ck::index_t> K;
const ck::index_t dims_arg_num = 9;
collect_index_params(argv, M, dims_arg_num, 2);
collect_index_params(argv, N, dims_arg_num + 2, 2);
collect_index_params(argv, K, dims_arg_num + 4, 2);
std::vector<ck::index_t> StridesA;
std::vector<ck::index_t> StridesB;
std::vector<ck::index_t> StridesC;
std::vector<ck::index_t> StridesD;
if(!default_strides)
{
collect_index_params(argv, StridesA, dims_arg_num + 6, 4);
collect_index_params(argv, StridesB, dims_arg_num + 10, 4);
collect_index_params(argv, StridesC, dims_arg_num + 14, 4);
collect_index_params(argv, StridesD, dims_arg_num + 18, 4);
}
using F32 = float;
using F64 = double;
auto profile = [&](auto a_layout, auto b_layout, auto cd_layout, auto type) {
using ALayout = decltype(a_layout);
using BLayout = decltype(b_layout);
using CDLayout = decltype(cd_layout);
using DataType = decltype(type);
if(default_strides)
{
assign_default_strides(a_layout, StridesA, {M[0], M[1], K[0], K[1]});
assign_default_strides(b_layout, StridesB, {K[0], K[1], N[0], N[1]});
assign_default_strides(cd_layout, StridesC, {M[0], M[1], N[0], N[1]});
assign_default_strides(cd_layout, StridesD, {M[0], M[1], N[0], N[1]});
}
bool pass = ck::profiler::
profile_contraction_impl<ALayout, BLayout, CDLayout, DataType, ck::Tuple<>, Scale>(
do_verification,
init_method,
do_log,
time_kernel,
Scale{alpha},
M,
N,
K,
StridesA,
StridesB,
StridesC,
StridesD);
return pass;
};
if(data_type == ContractionDataType::F32_F32_F32_F32 &&
layout == ContractionMatrixLayout::MK_KN_MN_MN)
{
return profile(Row{}, Row{}, Row{}, F32{});
}
else if(data_type == ContractionDataType::F32_F32_F32_F32 &&
layout == ContractionMatrixLayout::MK_NK_MN_MN)
{
return profile(Row{}, Col{}, Row{}, F32{});
}
else if(data_type == ContractionDataType::F32_F32_F32_F32 &&
layout == ContractionMatrixLayout::KM_KN_MN_MN)
{
return profile(Col{}, Row{}, Row{}, F32{});
}
else if(data_type == ContractionDataType::F32_F32_F32_F32 &&
layout == ContractionMatrixLayout::KM_NK_MN_MN)
{
return profile(Col{}, Col{}, Row{}, F32{});
}
else if(data_type == ContractionDataType::F64_F64_F64_F64 &&
layout == ContractionMatrixLayout::MK_KN_MN_MN)
{
return profile(Row{}, Row{}, Row{}, F64{});
}
else if(data_type == ContractionDataType::F64_F64_F64_F64 &&
layout == ContractionMatrixLayout::MK_NK_MN_MN)
{
return profile(Row{}, Col{}, Row{}, F64{});
}
else if(data_type == ContractionDataType::F64_F64_F64_F64 &&
layout == ContractionMatrixLayout::KM_KN_MN_MN)
{
return profile(Col{}, Row{}, Row{}, F64{});
}
else if(data_type == ContractionDataType::F64_F64_F64_F64 &&
layout == ContractionMatrixLayout::KM_NK_MN_MN)
{
return profile(Col{}, Col{}, Row{}, F64{});
}
else
{
std::cout << "this data_type & layout is not implemented" << std::endl;
return 1;
}
}
REGISTER_PROFILER_OPERATION(OP_NAME, OP_DESC, profile_contraction_scale);
...@@ -91,3 +91,8 @@ TEST(TestContractionInterface, IncorrectDataTypes) ...@@ -91,3 +91,8 @@ TEST(TestContractionInterface, IncorrectDataTypes)
EXPECT_FALSE(wrapper_1.IsSupported()); EXPECT_FALSE(wrapper_1.IsSupported());
EXPECT_FALSE(wrapper_2.IsSupported()); EXPECT_FALSE(wrapper_2.IsSupported());
} }
// TEST(TestContractionInterface, CornerCases)
// {
// EXPECT_FALSE()
// }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment