Commit 540d43b4 authored by root's avatar root
Browse files

Merge branch 'aosewski/ggemm_splitk' of...

Merge branch 'aosewski/ggemm_splitk' of https://github.com/ROCmSoftwarePlatform/composable_kernel into aosewski/ggemm_splitk
parents 923b7982 0b8319bd
...@@ -9,10 +9,10 @@ namespace device { ...@@ -9,10 +9,10 @@ namespace device {
namespace instance { namespace instance {
void add_device_conv2d_xdl_perlayer_quantization_int8_instances( void add_device_conv2d_xdl_perlayer_quantization_int8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<NDimSpatial, std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<NDimSpatial,
GNHWC, NHWGC,
GKYXC, GKYXC,
Empty_Tuple, Empty_Tuple,
GNHWK, NHWGK,
int8_t, int8_t,
int8_t, int8_t,
Empty_Tuple, Empty_Tuple,
...@@ -22,19 +22,28 @@ void add_device_conv2d_xdl_perlayer_quantization_int8_instances( ...@@ -22,19 +22,28 @@ void add_device_conv2d_xdl_perlayer_quantization_int8_instances(
Mul_Clamp>>>& instances) Mul_Clamp>>>& instances)
{ {
add_device_operation_instances(instances, add_device_operation_instances(instances,
device_grouped_conv2d_xdl_int8_instances<Empty_Tuple, device_grouped_conv2d_xdl_int8_instances<NHWGC,
GKYXC,
Empty_Tuple,
NHWGK,
Empty_Tuple, Empty_Tuple,
Mul_Clamp, Mul_Clamp,
ConvFwdDefault, ConvFwdDefault,
16>{}); 16>{});
add_device_operation_instances(instances, add_device_operation_instances(instances,
device_grouped_conv2d_xdl_int8_instances<Empty_Tuple, device_grouped_conv2d_xdl_int8_instances<NHWGC,
GKYXC,
Empty_Tuple,
NHWGK,
Empty_Tuple, Empty_Tuple,
Mul_Clamp, Mul_Clamp,
ConvFwd1x1P0, ConvFwd1x1P0,
16>{}); 16>{});
add_device_operation_instances(instances, add_device_operation_instances(instances,
device_grouped_conv2d_xdl_int8_instances<Empty_Tuple, device_grouped_conv2d_xdl_int8_instances<NHWGC,
GKYXC,
Empty_Tuple,
NHWGK,
Empty_Tuple, Empty_Tuple,
Mul_Clamp, Mul_Clamp,
ConvFwd1x1S1P0, ConvFwd1x1S1P0,
...@@ -43,10 +52,10 @@ void add_device_conv2d_xdl_perlayer_quantization_int8_instances( ...@@ -43,10 +52,10 @@ void add_device_conv2d_xdl_perlayer_quantization_int8_instances(
void add_device_conv2d_xdl_relu_perlayer_quantization_int8_instances( void add_device_conv2d_xdl_relu_perlayer_quantization_int8_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<NDimSpatial, std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleD<NDimSpatial,
GNHWC, NHWGC,
GKYXC, GKYXC,
Empty_Tuple, Empty_Tuple,
GNHWK, NHWGK,
int8_t, int8_t,
int8_t, int8_t,
Empty_Tuple, Empty_Tuple,
...@@ -56,19 +65,28 @@ void add_device_conv2d_xdl_relu_perlayer_quantization_int8_instances( ...@@ -56,19 +65,28 @@ void add_device_conv2d_xdl_relu_perlayer_quantization_int8_instances(
Relu_Mul_Clamp>>>& instances) Relu_Mul_Clamp>>>& instances)
{ {
add_device_operation_instances(instances, add_device_operation_instances(instances,
device_grouped_conv2d_xdl_int8_instances<Empty_Tuple, device_grouped_conv2d_xdl_int8_instances<NHWGC,
GKYXC,
Empty_Tuple,
NHWGK,
Empty_Tuple, Empty_Tuple,
Relu_Mul_Clamp, Relu_Mul_Clamp,
ConvFwdDefault, ConvFwdDefault,
16>{}); 16>{});
add_device_operation_instances(instances, add_device_operation_instances(instances,
device_grouped_conv2d_xdl_int8_instances<Empty_Tuple, device_grouped_conv2d_xdl_int8_instances<NHWGC,
GKYXC,
Empty_Tuple,
NHWGK,
Empty_Tuple, Empty_Tuple,
Relu_Mul_Clamp, Relu_Mul_Clamp,
ConvFwd1x1P0, ConvFwd1x1P0,
16>{}); 16>{});
add_device_operation_instances(instances, add_device_operation_instances(instances,
device_grouped_conv2d_xdl_int8_instances<Empty_Tuple, device_grouped_conv2d_xdl_int8_instances<NHWGC,
GKYXC,
Empty_Tuple,
NHWGK,
Empty_Tuple, Empty_Tuple,
Relu_Mul_Clamp, Relu_Mul_Clamp,
ConvFwd1x1S1P0, ConvFwd1x1S1P0,
......
...@@ -72,8 +72,8 @@ bool profile_gemm_splitk_impl(int do_verification, ...@@ -72,8 +72,8 @@ bool profile_gemm_splitk_impl(int do_verification,
{ {
case 0: break; case 0: break;
case 1: case 1:
a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5}); a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{0, 1});
b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5}); b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-1, 1});
break; break;
default: default:
a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0}); a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment