Unverified Commit d25fcb3d authored by zjing14's avatar zjing14 Committed by GitHub
Browse files

Merge branch 'develop' into navi3x_add_vectorload_check

parents 270dc0a3 7613c1d9
......@@ -9,10 +9,10 @@ namespace device {
namespace instance {
void add_device_conv2d_xdl_perlayer_quantization_int8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<NDimSpatial,
GNHWC,
NHWGC,
GKYXC,
Empty_Tuple,
GNHWK,
NHWGK,
int8_t,
int8_t,
Empty_Tuple,
......@@ -22,19 +22,28 @@ void add_device_conv2d_xdl_perlayer_quantization_int8_instances(
Mul_Clamp>>>& instances)
{
add_device_operation_instances(instances,
device_grouped_conv2d_xdl_int8_instances<Empty_Tuple,
device_grouped_conv2d_xdl_int8_instances<NHWGC,
GKYXC,
Empty_Tuple,
NHWGK,
Empty_Tuple,
Mul_Clamp,
ConvFwdDefault,
16>{});
add_device_operation_instances(instances,
device_grouped_conv2d_xdl_int8_instances<Empty_Tuple,
device_grouped_conv2d_xdl_int8_instances<NHWGC,
GKYXC,
Empty_Tuple,
NHWGK,
Empty_Tuple,
Mul_Clamp,
ConvFwd1x1P0,
16>{});
add_device_operation_instances(instances,
device_grouped_conv2d_xdl_int8_instances<Empty_Tuple,
device_grouped_conv2d_xdl_int8_instances<NHWGC,
GKYXC,
Empty_Tuple,
NHWGK,
Empty_Tuple,
Mul_Clamp,
ConvFwd1x1S1P0,
......@@ -43,10 +52,10 @@ void add_device_conv2d_xdl_perlayer_quantization_int8_instances(
void add_device_conv2d_xdl_relu_perlayer_quantization_int8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<NDimSpatial,
GNHWC,
NHWGC,
GKYXC,
Empty_Tuple,
GNHWK,
NHWGK,
int8_t,
int8_t,
Empty_Tuple,
......@@ -56,19 +65,28 @@ void add_device_conv2d_xdl_relu_perlayer_quantization_int8_instances(
Relu_Mul_Clamp>>>& instances)
{
add_device_operation_instances(instances,
device_grouped_conv2d_xdl_int8_instances<Empty_Tuple,
device_grouped_conv2d_xdl_int8_instances<NHWGC,
GKYXC,
Empty_Tuple,
NHWGK,
Empty_Tuple,
Relu_Mul_Clamp,
ConvFwdDefault,
16>{});
add_device_operation_instances(instances,
device_grouped_conv2d_xdl_int8_instances<Empty_Tuple,
device_grouped_conv2d_xdl_int8_instances<NHWGC,
GKYXC,
Empty_Tuple,
NHWGK,
Empty_Tuple,
Relu_Mul_Clamp,
ConvFwd1x1P0,
16>{});
add_device_operation_instances(instances,
device_grouped_conv2d_xdl_int8_instances<Empty_Tuple,
device_grouped_conv2d_xdl_int8_instances<NHWGC,
GKYXC,
Empty_Tuple,
NHWGK,
Empty_Tuple,
Relu_Mul_Clamp,
ConvFwd1x1S1P0,
......
......@@ -72,8 +72,8 @@ bool profile_gemm_splitk_impl(int do_verification,
{
case 0: break;
case 1:
a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{0, 1});
b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-1, 1});
break;
default:
a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
......
......@@ -8,6 +8,7 @@
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_gemm.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_gemm_splitk.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_gemm.hpp"
......@@ -39,7 +40,8 @@ bool profile_grouped_gemm_impl(int do_verification,
const std::vector<int>& Ks,
const std::vector<int>& StrideAs,
const std::vector<int>& StrideBs,
const std::vector<int>& StrideCs)
const std::vector<int>& StrideCs,
int kbatch = 1)
{
bool pass = true;
......@@ -96,8 +98,6 @@ bool profile_grouped_gemm_impl(int do_verification,
a_m_k[i].GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0}, num_thread);
b_k_n[i].GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5}, num_thread);
}
c_m_n_device_results[i].GenerateTensorValue(GeneratorTensor_0<CDataType>{}, num_thread);
}
using AElementOp = ck::tensor_operation::element_wise::PassThrough;
......@@ -132,13 +132,12 @@ bool profile_grouped_gemm_impl(int do_verification,
std::make_unique<DeviceMem>(sizeof(ADataType) * a_m_k[i].mDesc.GetElementSpaceSize()));
b_device_buf.emplace_back(
std::make_unique<DeviceMem>(sizeof(BDataType) * b_k_n[i].mDesc.GetElementSpaceSize()));
c_device_buf.emplace_back(std::make_unique<DeviceMem>(
sizeof(CDataType) * c_m_n_device_results[i].mDesc.GetElementSpaceSize()));
a_device_buf[i]->ToDevice(a_m_k[i].mData.data());
b_device_buf[i]->ToDevice(b_k_n[i].mData.data());
c_device_buf[i]->ToDevice(c_m_n_device_results[i].mData.data());
c_device_buf[i]->SetZero();
gemm_descs.push_back({Ms[i], Ns[i], Ks[i], StrideAs[i], StrideBs[i], StrideCs[i], {}});
......@@ -197,6 +196,28 @@ bool profile_grouped_gemm_impl(int do_verification,
{
std::string gemm_name = gemm_ptr->GetTypeString();
if(kbatch > 1)
{
using DeviceOpSplitK =
ck::tensor_operation::device::DeviceGroupedGemmSplitK<ALayout,
BLayout,
ck::Tuple<>,
CLayout,
ADataType,
BDataType,
ck::Tuple<>,
CDataType,
AElementOp,
BElementOp,
CElementOp>;
if(dynamic_cast<DeviceOpSplitK*>(gemm_ptr.get()) != nullptr)
{
dynamic_cast<DeviceOpSplitK*>(gemm_ptr.get())
->SetKBatchSize(argument_ptr.get(), kbatch);
}
}
float ave_time =
invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel});
......
......@@ -52,20 +52,24 @@ std::vector<int> argToIntArray(char* input)
int profile_grouped_gemm(int argc, char* argv[])
{
if(!(argc == 14))
if(argc < 14)
{
printf("arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n");
printf("arg2: data type (0: fp32; 1: fp16; 2: bf16; 3: int8)\n");
printf("arg3: matrix layout (0: A[m, k] * B[k, n] = C[m, n];\n");
printf(" 1: A[m, k] * B[n, k] = C[m, n];\n");
printf(" 2: A[k, m] * B[k, n] = C[m, n];\n");
printf(" 3: A[k, m] * B[n, k] = C[m, n])\n");
printf("arg4: verification (0: no; 1: yes)\n");
printf("arg5: initialization (0: no init; 1: integer value; 2: decimal value)\n");
printf("arg6: print tensor value (0: no; 1: yes)\n");
printf("arg7: time kernel (0=n0, 1=yes)\n");
printf("arg8 to 13: Ms, Ns, Ks, StrideAs, StrideBs, StrideCs (e.g., 256,256 128,128 64,64 "
"64,64 64,64 128,128)\n");
std::cout
<< "arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n"
<< "arg2: data type (0: fp32; 1: fp16; 2: bf16; 3: int8)\n"
<< "arg3: matrix layout (0: A[m, k] * B[k, n] = C[m, n];\n"
<< " 1: A[m, k] * B[n, k] = C[m, n];\n"
<< " 2: A[k, m] * B[k, n] = C[m, n];\n"
<< " 3: A[k, m] * B[n, k] = C[m, n])\n"
<< "arg4: verification (0: no; 1: yes)\n"
<< "arg5: initialization (0: no init; 1: integer value; 2: decimal value)\n"
<< "arg6: print tensor value (0: no; 1: yes)\n"
<< "arg7: time kernel (0=n0, 1=yes)\n"
<< "arg8 to 13: Ms, Ns, Ks, StrideAs, StrideBs, StrideCs (e.g., 256,256 128,128 64,64 "
"64,64 64,64 128,128)\n"
<< "arg15: kbatch value (default 4)\n"
<< std::endl;
exit(1);
}
......@@ -83,6 +87,7 @@ int profile_grouped_gemm(int argc, char* argv[])
const auto StrideAs = argToIntArray(argv[11]);
const auto StrideBs = argToIntArray(argv[12]);
const auto StrideCs = argToIntArray(argv[13]);
const int kbatch = argc == 15 ? std::stoi(argv[14]) : 1;
if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_KN_MN)
{
......@@ -101,7 +106,8 @@ int profile_grouped_gemm(int argc, char* argv[])
Ks,
StrideAs,
StrideBs,
StrideCs);
StrideCs,
kbatch);
}
else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::MK_NK_MN)
{
......@@ -120,7 +126,8 @@ int profile_grouped_gemm(int argc, char* argv[])
Ks,
StrideAs,
StrideBs,
StrideCs);
StrideCs,
kbatch);
}
else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_KN_MN)
{
......@@ -139,7 +146,8 @@ int profile_grouped_gemm(int argc, char* argv[])
Ks,
StrideAs,
StrideBs,
StrideCs);
StrideCs,
kbatch);
}
else if(data_type == GemmDataType::F16_F16_F16 && layout == GemmMatrixLayout::KM_NK_MN)
{
......@@ -158,7 +166,8 @@ int profile_grouped_gemm(int argc, char* argv[])
Ks,
StrideAs,
StrideBs,
StrideCs);
StrideCs,
kbatch);
}
else
{
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment