Commit 6257e347 authored by Chao Liu

clean

parent ac876c6f
@@ -18,7 +18,7 @@ enum struct ConvolutionForwardSpecialization
     OddC,
 };
 
-inline std::string getConvFwdSpecializationStr(const ConvolutionForwardSpecialization& s)
+inline std::string getConvForwardSpecializationString(const ConvolutionForwardSpecialization& s)
 {
     switch(s)
     {
...
@@ -871,7 +871,7 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
             << MPerBlock << ", "
             << NPerBlock << ", "
             << K0PerBlock << ", "
-            << getConvFwdSpecializationStr(ConvForwardSpecialization)
+            << getConvForwardSpecializationString(ConvForwardSpecialization)
             << ">";
         // clang-format on
...
@@ -711,7 +711,7 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
             << MPerBlock << ", "
             << NPerBlock << ", "
             << K0PerBlock << ", "
-            << getConvFwdSpecializationStr(ConvForwardSpecialization)
+            << getConvForwardSpecializationString(ConvForwardSpecialization)
             << ">";
         // clang-format on
...
@@ -1033,7 +1033,7 @@ struct DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
             << MPerBlock << ", "
             << NPerBlock << ", "
             << K0PerBlock << ", "
-            << getConvFwdSpecializationStr(ConvForwardSpecialization)
+            << getConvForwardSpecializationString(ConvForwardSpecialization)
             << ">";
         // clang-format on
...
@@ -746,7 +746,8 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
             << NPerBlock << ", "
             << KPerBlock << ", "
             << AK1 << ", "
-            << BK1
+            << BK1 << ", "
+            << getGemmSpecializationString(GemmSpec)
             << ">";
         // clang-format on
...
@@ -19,6 +19,22 @@ enum struct GemmSpecialization
     MNKPadding,
 };
 
+inline std::string getGemmSpecializationString(const GemmSpecialization& s)
+{
+    switch(s)
+    {
+    case GemmSpecialization::Default: return "Default";
+    case GemmSpecialization::MPadding: return "MPadding";
+    case GemmSpecialization::NPadding: return "NPadding";
+    case GemmSpecialization::KPadding: return "KPadding";
+    case GemmSpecialization::MNPadding: return "MNPadding";
+    case GemmSpecialization::MKPadding: return "MKPadding";
+    case GemmSpecialization::NKPadding: return "NKPadding";
+    case GemmSpecialization::MNKPadding: return "MNKPadding";
+    default: return "Unrecognized specialization!";
+    }
+}
+
 } // namespace device
 } // namespace tensor_operation
 } // namespace ck
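
A quick host-side sanity check of the new helper might look like the sketch below; the include path is a guess, and only the enum values and namespaces shown in the hunk above are assumed to exist.

    #include <iostream>
    // Path is an assumption; use whatever header defines GemmSpecialization in your tree.
    #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"

    int main()
    {
        using ck::tensor_operation::device::GemmSpecialization;
        using ck::tensor_operation::device::getGemmSpecializationString;

        // Prints "MNKPadding", matching the tag now appended to device-op type strings.
        std::cout << getGemmSpecializationString(GemmSpecialization::MNKPadding) << "\n";
        return 0;
    }
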
@@ -35,7 +35,6 @@ struct Add
         y = type_convert<half_t>(x0) + x1;
     };
 
-    // Question: should half_t be supported ?
     template <>
     __host__ __device__ constexpr void
     operator()<half_t>(half_t& y, const half_t& x0, const half_t& x1) const
@@ -43,7 +42,6 @@ struct Add
         y = x0 + x1;
     };
 
-    // Question: should bhalf_t be supported ?
     template <>
     __host__ __device__ constexpr void
     operator()<bhalf_t>(bhalf_t& y, const bhalf_t& x0, const bhalf_t& x1) const
@@ -74,7 +72,6 @@ struct Subtract
         y = x0 - x1;
     };
 
-    // Question: should half_t be supported ?
     template <>
     __host__ __device__ constexpr void
     operator()<half_t>(half_t& y, const half_t& x0, const half_t& x1) const
@@ -82,7 +79,6 @@ struct Subtract
         y = x0 - x1;
     };
 
-    // Question: should bhalf_t be supported ?
     template <>
     __host__ __device__ constexpr void
     operator()<bhalf_t>(bhalf_t& y, const bhalf_t& x0, const bhalf_t& x1) const
@@ -98,32 +94,6 @@ struct Bilinear
 {
     Bilinear(float alpha, float beta) : alpha_(alpha), beta_(beta){};
 
-#if 0
-    template <typename T>
-    __host__ __device__ constexpr void operator()(T& y, const T& x0, const T& x1) const;
-
-    template <>
-    __host__ __device__ constexpr void
-    operator()<float>(float& y, const float& x0, const float& x1) const
-    {
-        y = alpha_ * x0 + beta_ * x1;
-    };
-
-    template <>
-    __host__ __device__ constexpr void
-    operator()<double>(double& y, const double& x0, const double& x1) const
-    {
-        y = type_convert<double>(alpha_) * x0 + type_convert<double>(beta_) * x1;
-    };
-
-    template <>
-    __host__ __device__ constexpr void
-    operator()<half_t>(half_t& y, const half_t& x0, const half_t& x1) const
-    {
-        y = type_convert<half_t>(alpha_ * type_convert<float>(x0) +
-                                 beta_ * type_convert<float>(x1));
-    };
-#else
     template <typename Y, typename X0, typename X1>
     __host__ __device__ constexpr void operator()(Y&, const X0&, const X1&) const;
@@ -140,7 +110,6 @@ struct Bilinear
     {
         y = type_convert<half_t>(alpha_ * x0 + beta_ * ck::type_convert<float>(x1));
     };
-#endif
 
     float alpha_;
     float beta_;
@@ -167,7 +136,6 @@ struct AddRelu
         y = a > 0.0 ? a : 0.0;
     };
 
-    // Question: should half_t be supported ?
     template <>
     __host__ __device__ constexpr void
     operator()<half_t>(half_t& y, const half_t& x0, const half_t& x1) const
@@ -202,7 +170,6 @@ struct AddHardswish
         y = c;
     };
 
-    // Question: should half_t be supported ?
     template <>
     __host__ __device__ constexpr void
     operator()<half_t>(half_t& y, const half_t& x0, const half_t& x1) const
...
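
The Bilinear functor retained above templates the output and both inputs separately, so mixed-precision epilogues can be expressed directly. A minimal standalone analog of the kept mixed-type form, written as a free function so it compiles with a plain host compiler; the parameter types are inferred from the visible body, not spelled out in the hunk, and half_t / type_convert are the CK library types.

    // Hypothetical analog of the kept Bilinear specialization: half_t output,
    // float x0 (GEMM accumulator), half_t x1 (D tensor value).
    inline ck::half_t bilinear_half(float alpha, float beta, float x0, ck::half_t x1)
    {
        return ck::type_convert<ck::half_t>(alpha * x0 + beta * ck::type_convert<float>(x1));
    }
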
@@ -30,18 +30,18 @@ template <typename ADataType,
           typename BLayout,
           typename DELayout> // assume Ds and E have same layout
 bool profile_gemm_bilinear_impl(int do_verification,
                                 int init_method,
                                 bool /*do_log*/,
                                 bool time_kernel,
                                 int M,
                                 int N,
                                 int K,
                                 int StrideA,
                                 int StrideB,
                                 int StrideD,
                                 int StrideE,
                                 float alpha,
                                 float beta)
 {
     auto f_host_tensor_descriptor =
         [](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
...
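
For reference, a call matching the signature above; the data types, layout tag, and problem sizes are illustrative assumptions (the real wiring lives in the profile_gemm_bilinear driver shown further down).

    // Hypothetical invocation; types, layouts, and sizes chosen for illustration only.
    using Row = ck::tensor_layout::gemm::RowMajor;

    bool pass = ck::profiler::profile_gemm_bilinear_impl<ck::half_t, // ADataType
                                                         ck::half_t, // BDataType
                                                         float,      // AccDataType
                                                         ck::half_t, // DDataType
                                                         ck::half_t, // EDataType
                                                         Row,        // ALayout
                                                         Row,        // BLayout
                                                         Row>        // DELayout
        (1 /*do_verification*/, 1 /*init_method*/, false /*do_log*/, true /*time_kernel*/,
         3840, 4096, 4096,        // M, N, K
         4096, 4096, 4096, 4096,  // StrideA, StrideB, StrideD, StrideE
         1.0f, 0.5f);             // alpha, beta
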
@@ -29,7 +29,7 @@ int profile_gemm_add_add_fastgelu(int argc, char* argv[])
     if(argc != 16)
     {
         // clang-format off
-        printf("arg1: tensor operation (gemm_add_add_fastgelu: GEMM+Add+Add+GeLU)\n");
+        printf("arg1: tensor operation (gemm_add_add_fastgelu: GEMM+Add+Add+FastGeLU)\n");
         printf("arg2: data type (0: fp32; 1: fp16; 2: bf16; 3: int8)\n");
         printf("arg3: matrix layout (0: E[m, n] = FastGeLU(A[m, k] * B[k, n] + D0[m, n] + D1[m, n]);\n");
         printf("                     1: E[m, n] = FastGeLU(A[m, k] * B[n, k] + D0[m, n] + D1[m, n]);\n");
...
@@ -29,7 +29,7 @@ int profile_gemm_bilinear(int argc, char* argv[])
     if(argc != 17)
     {
         // clang-format off
-        printf("arg1: tensor operation (gemm_add_add_fastgelu: GEMM+Add+Add+GeLU)\n");
+        printf("arg1: tensor operation (gemm_bilinear: GEMM+Bilinear)\n");
         printf("arg2: data type (0: fp32; 1: fp16; 2: bf16; 3: int8)\n");
         printf("arg3: matrix layout (0: E[m, n] = alpha * A[m, k] * B[k, n] + beta * D[m, n];\n");
         printf("                     1: E[m, n] = alpha * A[m, k] * B[n, k] + beta * D[m, n];\n");
@@ -94,13 +94,13 @@ int profile_gemm_bilinear(int argc, char* argv[])
     const int DefaultStrideE = ck::is_same_v<DELayout, Row> ? N : M;
 
     bool pass = ck::profiler::profile_gemm_bilinear_impl<ADataType,
                                                          BDataType,
                                                          AccDataType,
                                                          DDataType,
                                                          EDataType,
                                                          ALayout,
                                                          BLayout,
                                                          DELayout>(
         do_verification,
         init_method,
         do_log,
...