Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
4511f877
"src/include/gridwise_winograd_convolution.cuh" did not exist on "dbffe05a989179d027cdbd3c2a2952a69a44a98e"
Commit
4511f877
authored
May 09, 2022
by
Chao Liu
Browse files
refactor profiler
parent
519b6aaf
Changes
69
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
434 additions
and
739 deletions
+434
-739
include/ck/tensor_operation/gpu/device/device_gemm_xdl_splitk.hpp
...ck/tensor_operation/gpu/device/device_gemm_xdl_splitk.hpp
+154
-252
include/ck/tensor_operation/gpu/device/device_gemm_xdl_splitk_c_shuffle.hpp
...operation/gpu/device/device_gemm_xdl_splitk_c_shuffle.hpp
+162
-300
library/include/ck/library/reference_tensor_operation/cpu/reference_gemm.hpp
...library/reference_tensor_operation/cpu/reference_gemm.hpp
+3
-3
library/src/tensor_operation_instance/gpu/CMakeLists.txt
library/src/tensor_operation_instance/gpu/CMakeLists.txt
+1
-1
library/src/tensor_operation_instance/gpu/batched_gemm/CMakeLists.txt
...tensor_operation_instance/gpu/batched_gemm/CMakeLists.txt
+4
-4
library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_i8_i8_i8_gkm_gkn_gmn_instance.cpp
...device_batched_gemm_xdl_i8_i8_i8_gkm_gkn_gmn_instance.cpp
+3
-3
library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_i8_i8_i8_gkm_gnk_gmn_instance.cpp
...device_batched_gemm_xdl_i8_i8_i8_gkm_gnk_gmn_instance.cpp
+66
-0
library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_i8_i8_i8_gmk_gkn_gmn_instance.cpp
...device_batched_gemm_xdl_i8_i8_i8_gmk_gkn_gmn_instance.cpp
+3
-3
library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_i8_i8_i8_gmk_gnk_gmn_instance.cpp
...device_batched_gemm_xdl_i8_i8_i8_gmk_gnk_gmn_instance.cpp
+3
-3
library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_int8_int8_int8_gkm_gnk_gmn_instance.cpp
..._batched_gemm_xdl_int8_int8_int8_gkm_gnk_gmn_instance.cpp
+0
-66
library/src/tensor_operation_instance/gpu/conv2d_fwd_bias_relu_atomic_add/CMakeLists.txt
...stance/gpu/conv2d_fwd_bias_relu_atomic_add/CMakeLists.txt
+0
-11
library/src/tensor_operation_instance/gpu/conv2d_fwd_bias_relu_atomic_add/device_conv2d_fwd_xdl_c_shuffle_bias_relu_atomic_add_nhwc_kyxc_nhwk_f16_instance.cpp
...ffle_bias_relu_atomic_add_nhwc_kyxc_nhwk_f16_instance.cpp
+0
-69
library/src/tensor_operation_instance/gpu/gemm/CMakeLists.txt
...ary/src/tensor_operation_instance/gpu/gemm/CMakeLists.txt
+4
-12
library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instance.cpp
.../device_gemm_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instance.cpp
+3
-3
library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instance.cpp
.../device_gemm_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instance.cpp
+3
-3
library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instance.cpp
.../device_gemm_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instance.cpp
+3
-3
library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instance.cpp
.../device_gemm_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instance.cpp
+3
-3
library/src/tensor_operation_instance/gpu/gemm_splitk/CMakeLists.txt
.../tensor_operation_instance/gpu/gemm_splitk/CMakeLists.txt
+19
-0
library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_c_shuffle_f16_f16_f16_km_kn_mn_instance.cpp
...mm_xdl_splitk_c_shuffle_f16_f16_f16_km_kn_mn_instance.cpp
+0
-0
library/src/tensor_operation_instance/gpu/gemm_splitk/device_gemm_xdl_splitk_c_shuffle_f16_f16_f16_km_nk_mn_instance.cpp
...mm_xdl_splitk_c_shuffle_f16_f16_f16_km_nk_mn_instance.cpp
+0
-0
No files found.
include/ck/tensor_operation/gpu/device/device_gemm_xdl_splitk.hpp
View file @
4511f877
...
@@ -156,6 +156,8 @@ template <typename ADataType,
...
@@ -156,6 +156,8 @@ template <typename ADataType,
struct
DeviceGemmXdlSplitK
struct
DeviceGemmXdlSplitK
:
public
DeviceGemm
<
AElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
>
:
public
DeviceGemm
<
AElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
>
{
{
using
DeviceOp
=
DeviceGemmXdlSplitK
;
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
auto
I2
=
Number
<
2
>
{};
static
constexpr
auto
I2
=
Number
<
2
>
{};
...
@@ -426,7 +428,6 @@ struct DeviceGemmXdlSplitK
...
@@ -426,7 +428,6 @@ struct DeviceGemmXdlSplitK
p_b_grid_
{
p_b_grid
},
p_b_grid_
{
p_b_grid
},
p_c_grid_
{
p_c_grid
},
p_c_grid_
{
p_c_grid
},
BatchCount_
(
k_batch
),
BatchCount_
(
k_batch
),
has_tail_
(
false
),
compute_ptr_offset_of_batch_
{
0
,
0
},
compute_ptr_offset_of_batch_
{
0
,
0
},
M01_
{
M01
},
M01_
{
M01
},
N01_
{
N01
},
N01_
{
N01
},
...
@@ -444,27 +445,15 @@ struct DeviceGemmXdlSplitK
...
@@ -444,27 +445,15 @@ struct DeviceGemmXdlSplitK
DeviceGemmXdlSplitK
::
MakeBGridDescriptor_K0_N_K1
(
KSplitted
,
N
,
StrideB
);
DeviceGemmXdlSplitK
::
MakeBGridDescriptor_K0_N_K1
(
KSplitted
,
N
,
StrideB
);
c_grid_desc_m_n_
=
DeviceGemmXdlSplitK
::
MakeCGridDescriptor_M_N
(
M
,
N
,
StrideC
);
c_grid_desc_m_n_
=
DeviceGemmXdlSplitK
::
MakeCGridDescriptor_M_N
(
M
,
N
,
StrideC
);
bool
is_valid
=
GridwiseGemm
::
CheckValidity
(
a_grid_desc_k0_m_k1_
,
b_grid_desc_k0_n_k1_
,
c_grid_desc_m_n_
,
M01_
,
N01_
);
if
(
K
!=
KSplitted
*
BatchCount_
)
if
(
K
!=
KSplitted
*
BatchCount_
)
{
{
has_tail_
=
true
;
const
auto
KTail
=
K
-
KSplitted
*
(
BatchCount_
-
1
);
const
auto
KTail
=
K
-
KSplitted
*
(
BatchCount_
-
1
);
a_grid_desc_k0_m_k1_tail_
=
a_grid_desc_k0_m_k1_tail_
=
DeviceGemmXdlSplitK
::
MakeAGridDescriptor_K0_M_K1
(
M
,
KTail
,
StrideA
);
DeviceGemmXdlSplitK
::
MakeAGridDescriptor_K0_M_K1
(
M
,
KTail
,
StrideA
);
b_grid_desc_k0_n_k1_tail_
=
b_grid_desc_k0_n_k1_tail_
=
DeviceGemmXdlSplitK
::
MakeBGridDescriptor_K0_N_K1
(
KTail
,
N
,
StrideB
);
DeviceGemmXdlSplitK
::
MakeBGridDescriptor_K0_N_K1
(
KTail
,
N
,
StrideB
);
is_valid
&=
GridwiseGemm
::
CheckValidity
(
a_grid_desc_k0_m_k1_tail_
,
b_grid_desc_k0_n_k1_tail_
,
c_grid_desc_m_n_
,
M01_
,
N01_
);
}
}
if
(
is_valid
)
{
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_
=
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_
=
GridwiseGemm
::
MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
(
c_grid_desc_m_n_
);
GridwiseGemm
::
MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
(
c_grid_desc_m_n_
);
...
@@ -494,8 +483,12 @@ struct DeviceGemmXdlSplitK
...
@@ -494,8 +483,12 @@ struct DeviceGemmXdlSplitK
compute_ptr_offset_of_batch_
=
compute_ptr_offset_of_batch_
=
ComputePtrOffsetOfStridedBatch
{
a_batch_stride
,
b_batch_stride
};
ComputePtrOffsetOfStridedBatch
{
a_batch_stride
,
b_batch_stride
};
block_2_ctile_map_
=
MakeBlock2CTileMap
(
BatchCount_
,
c_grid_desc_m_n_
,
M01
,
N01
);
}
block_2_ctile_map_
=
MakeBlock2CTileMap
(
BatchCount_
,
c_grid_desc_m_n_
.
GetLength
(
I0
),
c_grid_desc_m_n_
.
GetLength
(
I1
),
M01
,
N01
);
}
}
// private:
// private:
...
@@ -503,7 +496,6 @@ struct DeviceGemmXdlSplitK
...
@@ -503,7 +496,6 @@ struct DeviceGemmXdlSplitK
const
BDataType
*
p_b_grid_
;
const
BDataType
*
p_b_grid_
;
CDataType
*
p_c_grid_
;
CDataType
*
p_c_grid_
;
index_t
BatchCount_
;
index_t
BatchCount_
;
bool
has_tail_
;
AGridDesc_K0_M_K1
a_grid_desc_k0_m_k1_
;
AGridDesc_K0_M_K1
a_grid_desc_k0_m_k1_
;
BGridDesc_K0_N_K1
b_grid_desc_k0_n_k1_
;
BGridDesc_K0_N_K1
b_grid_desc_k0_n_k1_
;
AGridDesc_K0_M_K1
a_grid_desc_k0_m_k1_tail_
;
AGridDesc_K0_M_K1
a_grid_desc_k0_m_k1_tail_
;
...
@@ -526,6 +518,11 @@ struct DeviceGemmXdlSplitK
...
@@ -526,6 +518,11 @@ struct DeviceGemmXdlSplitK
float
Run
(
const
Argument
&
arg
,
int
nrepeat
=
1
)
float
Run
(
const
Argument
&
arg
,
int
nrepeat
=
1
)
{
{
if
(
!
DeviceOp
::
IsSupportedArgument
(
arg
))
{
throw
std
::
runtime_error
(
"wrong! not supported"
);
}
{
{
std
::
cout
<<
"k_batch = "
<<
arg
.
BatchCount_
<<
"
\n
"
;
std
::
cout
<<
"k_batch = "
<<
arg
.
BatchCount_
<<
"
\n
"
;
std
::
cout
<<
"arg.a_grid_desc_k0_m_k1_{"
<<
arg
.
a_grid_desc_k0_m_k1_
.
GetLength
(
I0
)
std
::
cout
<<
"arg.a_grid_desc_k0_m_k1_{"
<<
arg
.
a_grid_desc_k0_m_k1_
.
GetLength
(
I0
)
...
@@ -539,8 +536,6 @@ struct DeviceGemmXdlSplitK
...
@@ -539,8 +536,6 @@ struct DeviceGemmXdlSplitK
std
::
cout
<<
"arg.c_grid_desc_m_n_{"
<<
arg
.
c_grid_desc_m_n_
.
GetLength
(
I0
)
<<
", "
std
::
cout
<<
"arg.c_grid_desc_m_n_{"
<<
arg
.
c_grid_desc_m_n_
.
GetLength
(
I0
)
<<
", "
<<
arg
.
c_grid_desc_m_n_
.
GetLength
(
I1
)
<<
"}"
<<
std
::
endl
;
<<
arg
.
c_grid_desc_m_n_
.
GetLength
(
I1
)
<<
"}"
<<
std
::
endl
;
if
(
arg
.
has_tail_
)
{
std
::
cout
<<
"arg.a_grid_desc_k0_m_k1_tail_{"
std
::
cout
<<
"arg.a_grid_desc_k0_m_k1_tail_{"
<<
arg
.
a_grid_desc_k0_m_k1_tail_
.
GetLength
(
I0
)
<<
", "
<<
arg
.
a_grid_desc_k0_m_k1_tail_
.
GetLength
(
I0
)
<<
", "
<<
arg
.
a_grid_desc_k0_m_k1_tail_
.
GetLength
(
I1
)
<<
", "
<<
arg
.
a_grid_desc_k0_m_k1_tail_
.
GetLength
(
I1
)
<<
", "
...
@@ -551,35 +546,12 @@ struct DeviceGemmXdlSplitK
...
@@ -551,35 +546,12 @@ struct DeviceGemmXdlSplitK
<<
arg
.
b_grid_desc_k0_n_k1_tail_
.
GetLength
(
I1
)
<<
", "
<<
arg
.
b_grid_desc_k0_n_k1_tail_
.
GetLength
(
I1
)
<<
", "
<<
arg
.
b_grid_desc_k0_n_k1_tail_
.
GetLength
(
I2
)
<<
"}"
<<
std
::
endl
;
<<
arg
.
b_grid_desc_k0_n_k1_tail_
.
GetLength
(
I2
)
<<
"}"
<<
std
::
endl
;
}
}
}
bool
is_valid
=
GridwiseGemm
::
CheckValidity
(
arg
.
a_grid_desc_k0_m_k1_
,
arg
.
b_grid_desc_k0_n_k1_
,
arg
.
c_grid_desc_m_n_
,
arg
.
M01_
,
arg
.
N01_
);
if
(
arg
.
has_tail_
)
{
is_valid
&=
GridwiseGemm
::
CheckValidity
(
arg
.
a_grid_desc_k0_m_k1_tail_
,
arg
.
b_grid_desc_k0_n_k1_tail_
,
arg
.
c_grid_desc_m_n_
,
arg
.
M01_
,
arg
.
N01_
);
}
if
(
!
is_valid
)
{
throw
std
::
runtime_error
(
"wrong! GridwiseBatchedGemm_km_kn_m0m1n0n1_xdlops_v2r3 has invalid setting"
);
}
const
index_t
grid_size
=
const
index_t
grid_size
=
GridwiseGemm
::
CalculateGridSize
(
arg
.
c_grid_desc_m_n_
)
*
arg
.
BatchCount_
;
GridwiseGemm
::
CalculateGridSize
(
arg
.
c_grid_desc_m_n_
)
*
arg
.
BatchCount_
;
float
ave_time
=
0
;
float
ave_time
=
0
;
if
(
arg
.
has_tail_
)
{
const
auto
K0
=
arg
.
a_grid_desc_k0_m_k1_
.
GetLength
(
I0
);
const
auto
K0
=
arg
.
a_grid_desc_k0_m_k1_
.
GetLength
(
I0
);
const
bool
has_main_k0_block_loop
=
GridwiseGemm
::
CalculateHasMainK0BlockLoop
(
K0
);
const
bool
has_main_k0_block_loop
=
GridwiseGemm
::
CalculateHasMainK0BlockLoop
(
K0
);
const
auto
K0_tail
=
arg
.
a_grid_desc_k0_m_k1_
.
GetLength
(
I0
);
const
auto
K0_tail
=
arg
.
a_grid_desc_k0_m_k1_
.
GetLength
(
I0
);
...
@@ -684,81 +656,6 @@ struct DeviceGemmXdlSplitK
...
@@ -684,81 +656,6 @@ struct DeviceGemmXdlSplitK
ave_time
=
Run
(
kernel
);
ave_time
=
Run
(
kernel
);
}
}
}
else
{
const
auto
K0
=
arg
.
a_grid_desc_k0_m_k1_
.
GetLength
(
I0
);
const
bool
has_main_k0_block_loop
=
GridwiseGemm
::
CalculateHasMainK0BlockLoop
(
K0
);
if
(
has_main_k0_block_loop
)
{
const
auto
kernel
=
ck
::
kernel_gemm_xdl_splitk
<
GridwiseGemm
,
ADataType
,
// TODO: distiguish A/B datatype
CDataType
,
remove_reference_t
<
DeviceGemmXdlSplitK
::
AGridDesc_K0_M_K1
>
,
remove_reference_t
<
DeviceGemmXdlSplitK
::
BGridDesc_K0_N_K1
>
,
remove_reference_t
<
CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
>
,
AElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
,
ComputePtrOffsetOfStridedBatch
,
remove_reference_t
<
Block2CTileMap
>
,
true
>
;
ave_time
=
launch_and_time_kernel
(
kernel
,
nrepeat
,
dim3
(
grid_size
),
dim3
(
BlockSize
),
0
,
arg
.
p_a_grid_
,
arg
.
p_b_grid_
,
arg
.
p_c_grid_
,
arg
.
BatchCount_
,
arg
.
a_grid_desc_k0_m_k1_
,
arg
.
b_grid_desc_k0_n_k1_
,
arg
.
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_
,
arg
.
a_element_op_
,
arg
.
b_element_op_
,
arg
.
c_element_op_
,
arg
.
compute_ptr_offset_of_batch_
,
arg
.
block_2_ctile_map_
);
}
else
{
const
auto
kernel
=
ck
::
kernel_gemm_xdl_splitk
<
GridwiseGemm
,
ADataType
,
// TODO: distiguish A/B datatype
CDataType
,
remove_reference_t
<
DeviceGemmXdlSplitK
::
AGridDesc_K0_M_K1
>
,
remove_reference_t
<
DeviceGemmXdlSplitK
::
BGridDesc_K0_N_K1
>
,
remove_reference_t
<
CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
>
,
AElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
,
ComputePtrOffsetOfStridedBatch
,
remove_reference_t
<
Block2CTileMap
>
,
false
>
;
ave_time
=
launch_and_time_kernel
(
kernel
,
nrepeat
,
dim3
(
grid_size
),
dim3
(
BlockSize
),
0
,
arg
.
p_a_grid_
,
arg
.
p_b_grid_
,
arg
.
p_c_grid_
,
arg
.
BatchCount_
,
arg
.
a_grid_desc_k0_m_k1_
,
arg
.
b_grid_desc_k0_n_k1_
,
arg
.
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_
,
arg
.
a_element_op_
,
arg
.
b_element_op_
,
arg
.
c_element_op_
,
arg
.
compute_ptr_offset_of_batch_
,
arg
.
block_2_ctile_map_
);
}
}
return
ave_time
;
return
ave_time
;
}
}
...
@@ -782,6 +679,11 @@ struct DeviceGemmXdlSplitK
...
@@ -782,6 +679,11 @@ struct DeviceGemmXdlSplitK
arg
.
b_grid_desc_k0_n_k1_
,
arg
.
b_grid_desc_k0_n_k1_
,
arg
.
c_grid_desc_m_n_
,
arg
.
c_grid_desc_m_n_
,
arg
.
M01_
,
arg
.
M01_
,
arg
.
N01_
)
&&
GridwiseGemm
::
CheckValidity
(
arg
.
a_grid_desc_k0_m_k1_tail_
,
arg
.
b_grid_desc_k0_n_k1_tail_
,
arg
.
c_grid_desc_m_n_
,
arg
.
M01_
,
arg
.
N01_
);
arg
.
N01_
);
}
}
...
...
include/ck/tensor_operation/gpu/device/device_gemm_xdl_splitk_c_shuffle.hpp
View file @
4511f877
...
@@ -167,15 +167,12 @@ struct DeviceGemmXdlSplitKCShuffle
...
@@ -167,15 +167,12 @@ struct DeviceGemmXdlSplitKCShuffle
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
auto
I2
=
Number
<
2
>
{};
static
constexpr
auto
I2
=
Number
<
2
>
{};
template
<
index_t
K1
>
static
auto
GetActualBatchAndKSplitted
(
index_t
KRaw
,
index_t
KBatch
)
static
auto
GetActualBatchAndKSplitted
(
index_t
K
,
index_t
KBatch
)
{
{
const
index_t
K0PerBlock
=
KPerBlock
/
K1
;
const
index_t
KSplitted
=
math
::
integer_divide_ceil
(
KRaw
,
KPerBlock
*
KBatch
)
*
KPerBlock
;
const
index_t
K0
=
math
::
integer_divide_ceil
(
K
,
KPerBlock
*
KBatch
)
*
K0PerBlock
;
const
index_t
actual_k_batch
=
math
::
integer_divide_ceil
(
KRaw
,
KSplitted
);
const
index_t
KSplitted
=
K0
*
K1
;
const
index_t
actual_batch
=
math
::
integer_divide_ceil
(
K
,
KSplitted
);
return
std
::
make_pair
(
actual_batch
,
KSplitted
);
return
std
::
make_pair
(
actual_
k_
batch
,
KSplitted
);
}
}
static
auto
MakeAGridDescriptor_AK0_M_AK1
(
index_t
MRaw
,
index_t
KRaw
,
index_t
StrideA
)
static
auto
MakeAGridDescriptor_AK0_M_AK1
(
index_t
MRaw
,
index_t
KRaw
,
index_t
StrideA
)
...
@@ -426,7 +423,7 @@ struct DeviceGemmXdlSplitKCShuffle
...
@@ -426,7 +423,7 @@ struct DeviceGemmXdlSplitKCShuffle
}
}
private:
private:
// TODO: should
they b
e long_index_t?
// TODO: should
we us
e long_index_t?
index_t
BatchStrideA_
;
index_t
BatchStrideA_
;
index_t
BatchStrideB_
;
index_t
BatchStrideB_
;
};
};
...
@@ -502,81 +499,61 @@ struct DeviceGemmXdlSplitKCShuffle
...
@@ -502,81 +499,61 @@ struct DeviceGemmXdlSplitKCShuffle
b_element_op_
{
b_element_op
},
b_element_op_
{
b_element_op
},
c_element_op_
{
c_element_op
}
c_element_op_
{
c_element_op
}
{
{
const
auto
actual_batch_and_ksplitted_A
=
const
auto
actual_batch_and_ksplitted
=
GetActualBatchAndKSplitted
(
KRaw
,
k_batch
);
GetActualBatchAndKSplitted
<
AK1
>
(
KRaw
,
k_batch
);
const
auto
actual_batch_and_ksplitted_B
=
GetActualBatchAndKSplitted
<
BK1
>
(
KRaw
,
k_batch
);
assert
(
actual_batch_and_ksplitted_A
.
first
=
=
actual_batch_and_ksplitted
_B
.
first
)
;
BatchCount_
=
actual_batch_and_ksplitted
.
first
;
BatchCount_
=
actual_batch_and_ksplitted_A
.
first
;
const
auto
KSplitted
=
actual_batch_and_ksplitted
.
second
;
const
auto
AKSplitted
=
actual_batch_and_ksplitted_A
.
second
;
const
auto
BKSplitted
=
actual_batch_and_ksplitted_B
.
second
;
a_grid_desc_ak0_m_ak1_
=
a_grid_desc_ak0_m_ak1_
=
DeviceOp
::
MakeAGridDescriptor_AK0_M_AK1
(
MRaw
,
A
KSplitted
,
StrideA
);
DeviceOp
::
MakeAGridDescriptor_AK0_M_AK1
(
MRaw
,
KSplitted
,
StrideA
);
b_grid_desc_bk0_n_bk1_
=
b_grid_desc_bk0_n_bk1_
=
DeviceOp
::
MakeBGridDescriptor_BK0_N_BK1
(
B
KSplitted
,
NRaw
,
StrideB
);
DeviceOp
::
MakeBGridDescriptor_BK0_N_BK1
(
KSplitted
,
NRaw
,
StrideB
);
c_grid_desc_m_n_
=
DeviceOp
::
MakeCGridDescriptor_M_N
(
MRaw
,
NRaw
,
StrideC
);
c_grid_desc_m_n_
=
DeviceOp
::
MakeCGridDescriptor_M_N
(
MRaw
,
NRaw
,
StrideC
);
is_valid_
=
GridwiseGemm
::
CheckValidity
(
if
(
KRaw
!=
KSplitted
*
BatchCount_
||
KRaw
!=
KSplitted
*
BatchCount_
)
a_grid_desc_ak0_m_ak1_
,
b_grid_desc_bk0_n_bk1_
,
c_grid_desc_m_n_
);
if
(
KRaw
!=
AKSplitted
*
BatchCount_
||
KRaw
!=
BKSplitted
*
BatchCount_
)
{
{
has_tail_
=
true
;
const
auto
KTail
=
KRaw
-
KSplitted
*
(
BatchCount_
-
1
);
const
auto
AKTail
=
KRaw
-
AKSplitted
*
(
BatchCount_
-
1
);
const
auto
BKTail
=
KRaw
-
BKSplitted
*
(
BatchCount_
-
1
);
a_grid_desc_ak0_m_ak1_tail_
=
a_grid_desc_ak0_m_ak1_tail_
=
DeviceOp
::
MakeAGridDescriptor_AK0_M_AK1
(
MRaw
,
A
KTail
,
StrideA
);
DeviceOp
::
MakeAGridDescriptor_AK0_M_AK1
(
MRaw
,
KTail
,
StrideA
);
b_grid_desc_bk0_n_bk1_tail_
=
b_grid_desc_bk0_n_bk1_tail_
=
DeviceOp
::
MakeBGridDescriptor_BK0_N_BK1
(
BKTail
,
NRaw
,
StrideB
);
DeviceOp
::
MakeBGridDescriptor_BK0_N_BK1
(
KTail
,
NRaw
,
StrideB
);
is_valid_
&=
GridwiseGemm
::
CheckValidity
(
a_grid_desc_ak0_m_ak1_tail_
,
b_grid_desc_bk0_n_bk1_tail_
,
c_grid_desc_m_n_
);
}
}
if
(
is_valid_
)
{
c_grid_desc_mblock_mperblock_nblock_nperblock_
=
c_grid_desc_mblock_mperblock_nblock_nperblock_
=
GridwiseGemm
::
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
GridwiseGemm
::
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
c_grid_desc_m_n_
);
c_grid_desc_m_n_
);
const
index_t
a_batch_stride
=
[
A
KSplitted
,
StrideA
]()
{
const
index_t
a_batch_stride
=
[
KSplitted
,
StrideA
]()
{
if
constexpr
(
is_same
<
tensor_layout
::
gemm
::
RowMajor
,
ALayout
>::
value
)
if
constexpr
(
is_same
<
tensor_layout
::
gemm
::
RowMajor
,
ALayout
>::
value
)
{
{
ignore
=
StrideA
;
ignore
=
StrideA
;
return
A
KSplitted
;
return
KSplitted
;
}
}
else
if
constexpr
(
is_same
<
tensor_layout
::
gemm
::
ColumnMajor
,
ALayout
>::
value
)
else
if
constexpr
(
is_same
<
tensor_layout
::
gemm
::
ColumnMajor
,
ALayout
>::
value
)
{
{
return
A
KSplitted
*
StrideA
;
return
KSplitted
*
StrideA
;
}
}
}();
}();
const
index_t
b_batch_stride
=
[
B
KSplitted
,
StrideB
]()
{
const
index_t
b_batch_stride
=
[
KSplitted
,
StrideB
]()
{
if
constexpr
(
is_same
<
tensor_layout
::
gemm
::
RowMajor
,
BLayout
>::
value
)
if
constexpr
(
is_same
<
tensor_layout
::
gemm
::
RowMajor
,
BLayout
>::
value
)
{
{
return
B
KSplitted
*
StrideB
;
return
KSplitted
*
StrideB
;
}
}
else
if
constexpr
(
is_same
<
tensor_layout
::
gemm
::
ColumnMajor
,
BLayout
>::
value
)
else
if
constexpr
(
is_same
<
tensor_layout
::
gemm
::
ColumnMajor
,
BLayout
>::
value
)
{
{
ignore
=
StrideB
;
ignore
=
StrideB
;
return
B
KSplitted
;
return
KSplitted
;
}
}
}();
}();
compute_ptr_offset_of_batch_
=
compute_ptr_offset_of_batch_
=
ComputePtrOffsetOfStridedBatch
{
a_batch_stride
,
b_batch_stride
};
ComputePtrOffsetOfStridedBatch
{
a_batch_stride
,
b_batch_stride
};
block_2_ctile_map_
=
MakeBlock2CTileMap
(
BatchCount_
,
block_2_ctile_map_
=
MakeBlock2CTileMap
(
c_grid_desc_m_n_
.
GetLength
(
I0
),
BatchCount_
,
c_grid_desc_m_n_
.
GetLength
(
I0
),
c_grid_desc_m_n_
.
GetLength
(
I1
),
1
,
1
);
c_grid_desc_m_n_
.
GetLength
(
I1
),
1
,
1
);
}
}
}
// private:
// private:
...
@@ -584,8 +561,6 @@ struct DeviceGemmXdlSplitKCShuffle
...
@@ -584,8 +561,6 @@ struct DeviceGemmXdlSplitKCShuffle
const
BDataType
*
p_b_grid_
;
const
BDataType
*
p_b_grid_
;
CDataType
*
p_c_grid_
;
CDataType
*
p_c_grid_
;
index_t
BatchCount_
;
index_t
BatchCount_
;
bool
has_tail_
;
bool
is_valid_
;
AGridDesc_AK0_M_AK1
a_grid_desc_ak0_m_ak1_
;
AGridDesc_AK0_M_AK1
a_grid_desc_ak0_m_ak1_
;
BGridDesc_BK0_N_BK1
b_grid_desc_bk0_n_bk1_
;
BGridDesc_BK0_N_BK1
b_grid_desc_bk0_n_bk1_
;
AGridDesc_AK0_M_AK1
a_grid_desc_ak0_m_ak1_tail_
;
AGridDesc_AK0_M_AK1
a_grid_desc_ak0_m_ak1_tail_
;
...
@@ -607,6 +582,11 @@ struct DeviceGemmXdlSplitKCShuffle
...
@@ -607,6 +582,11 @@ struct DeviceGemmXdlSplitKCShuffle
float
Run
(
const
Argument
&
arg
,
int
nrepeat
=
1
)
float
Run
(
const
Argument
&
arg
,
int
nrepeat
=
1
)
{
{
if
(
!
DeviceOp
::
IsSupportedArgument
(
arg
))
{
throw
std
::
runtime_error
(
"wrong! not supported"
);
}
{
{
std
::
cout
<<
"k_batch = "
<<
arg
.
BatchCount_
<<
"
\n
"
;
std
::
cout
<<
"k_batch = "
<<
arg
.
BatchCount_
<<
"
\n
"
;
std
::
cout
<<
"arg.a_grid_desc_ak0_m_ak1_{"
std
::
cout
<<
"arg.a_grid_desc_ak0_m_ak1_{"
...
@@ -622,9 +602,6 @@ struct DeviceGemmXdlSplitKCShuffle
...
@@ -622,9 +602,6 @@ struct DeviceGemmXdlSplitKCShuffle
std
::
cout
<<
"arg.c_grid_desc_m_n_{ "
<<
arg
.
c_grid_desc_m_n_
.
GetLength
(
I0
)
<<
", "
std
::
cout
<<
"arg.c_grid_desc_m_n_{ "
<<
arg
.
c_grid_desc_m_n_
.
GetLength
(
I0
)
<<
", "
<<
arg
.
c_grid_desc_m_n_
.
GetLength
(
I1
)
<<
"}"
<<
std
::
endl
;
<<
arg
.
c_grid_desc_m_n_
.
GetLength
(
I1
)
<<
"}"
<<
std
::
endl
;
if
(
arg
.
has_tail_
)
{
std
::
cout
<<
"arg.a_grid_desc_ak0_m_ak1_tail_{"
std
::
cout
<<
"arg.a_grid_desc_ak0_m_ak1_tail_{"
<<
arg
.
a_grid_desc_ak0_m_ak1_tail_
.
GetLength
(
I0
)
<<
", "
<<
arg
.
a_grid_desc_ak0_m_ak1_tail_
.
GetLength
(
I0
)
<<
", "
<<
arg
.
a_grid_desc_ak0_m_ak1_tail_
.
GetLength
(
I1
)
<<
", "
<<
arg
.
a_grid_desc_ak0_m_ak1_tail_
.
GetLength
(
I1
)
<<
", "
...
@@ -635,13 +612,6 @@ struct DeviceGemmXdlSplitKCShuffle
...
@@ -635,13 +612,6 @@ struct DeviceGemmXdlSplitKCShuffle
<<
arg
.
b_grid_desc_bk0_n_bk1_tail_
.
GetLength
(
I1
)
<<
", "
<<
arg
.
b_grid_desc_bk0_n_bk1_tail_
.
GetLength
(
I1
)
<<
", "
<<
arg
.
b_grid_desc_bk0_n_bk1_tail_
.
GetLength
(
I2
)
<<
"}"
<<
std
::
endl
;
<<
arg
.
b_grid_desc_bk0_n_bk1_tail_
.
GetLength
(
I2
)
<<
"}"
<<
std
::
endl
;
}
}
}
if
(
!
arg
.
is_valid_
)
{
throw
std
::
runtime_error
(
"wrong! GridwiseBatchedGemm_km_kn_m0m1n0n1_xdlops_v2r3 has invalid setting"
);
}
const
index_t
grid_size
=
const
index_t
grid_size
=
GridwiseGemm
::
CalculateGridSize
(
arg
.
c_grid_desc_m_n_
)
*
arg
.
BatchCount_
;
GridwiseGemm
::
CalculateGridSize
(
arg
.
c_grid_desc_m_n_
)
*
arg
.
BatchCount_
;
...
@@ -651,39 +621,12 @@ struct DeviceGemmXdlSplitKCShuffle
...
@@ -651,39 +621,12 @@ struct DeviceGemmXdlSplitKCShuffle
float
ave_time
=
0
;
float
ave_time
=
0
;
if
(
arg
.
has_tail_
)
{
const
auto
K0_tail
=
arg
.
a_grid_desc_ak0_m_ak1_tail_
.
GetLength
(
I0
);
const
auto
K0_tail
=
arg
.
a_grid_desc_ak0_m_ak1_tail_
.
GetLength
(
I0
);
const
bool
tail_has_main_k0_block_loop
=
const
bool
tail_has_main_k0_block_loop
=
GridwiseGemm
::
CalculateHasMainK0BlockLoop
(
K0_tail
);
GridwiseGemm
::
CalculateHasMainK0BlockLoop
(
K0_tail
);
const
auto
Run
=
[
&
](
const
auto
&
kernel
)
{
const
auto
Run
=
[
&
](
const
auto
&
kernel
)
{
if
(
nrepeat
==
0
)
return
launch_and_time_kernel
(
kernel
,
{
launch_kernel
(
kernel
,
dim3
(
grid_size
),
dim3
(
BlockSize
),
0
,
arg
.
p_a_grid_
,
arg
.
p_b_grid_
,
arg
.
p_c_grid_
,
arg
.
BatchCount_
,
arg
.
a_element_op_
,
arg
.
b_element_op_
,
arg
.
c_element_op_
,
arg
.
a_grid_desc_ak0_m_ak1_
,
arg
.
b_grid_desc_bk0_n_bk1_
,
arg
.
a_grid_desc_ak0_m_ak1_tail_
,
arg
.
b_grid_desc_bk0_n_bk1_tail_
,
arg
.
c_grid_desc_mblock_mperblock_nblock_nperblock_
,
arg
.
compute_ptr_offset_of_batch_
,
arg
.
block_2_ctile_map_
);
return
0.0
f
;
}
else
{
return
launch_and_time_kernel
(
kernel
,
nrepeat
,
nrepeat
,
dim3
(
grid_size
),
dim3
(
grid_size
),
dim3
(
BlockSize
),
dim3
(
BlockSize
),
...
@@ -702,7 +645,6 @@ struct DeviceGemmXdlSplitKCShuffle
...
@@ -702,7 +645,6 @@ struct DeviceGemmXdlSplitKCShuffle
arg
.
c_grid_desc_mblock_mperblock_nblock_nperblock_
,
arg
.
c_grid_desc_mblock_mperblock_nblock_nperblock_
,
arg
.
compute_ptr_offset_of_batch_
,
arg
.
compute_ptr_offset_of_batch_
,
arg
.
block_2_ctile_map_
);
arg
.
block_2_ctile_map_
);
}
};
};
if
(
has_main_k0_block_loop
&&
tail_has_main_k0_block_loop
)
if
(
has_main_k0_block_loop
&&
tail_has_main_k0_block_loop
)
...
@@ -781,90 +723,7 @@ struct DeviceGemmXdlSplitKCShuffle
...
@@ -781,90 +723,7 @@ struct DeviceGemmXdlSplitKCShuffle
ave_time
=
Run
(
kernel
);
ave_time
=
Run
(
kernel
);
}
}
}
else
{
const
auto
Run
=
[
&
](
const
auto
&
kernel
)
{
if
(
nrepeat
==
0
)
{
launch_kernel
(
kernel
,
dim3
(
grid_size
),
dim3
(
BlockSize
),
0
,
arg
.
p_a_grid_
,
arg
.
p_b_grid_
,
arg
.
p_c_grid_
,
arg
.
BatchCount_
,
arg
.
a_element_op_
,
arg
.
b_element_op_
,
arg
.
c_element_op_
,
arg
.
a_grid_desc_ak0_m_ak1_
,
arg
.
b_grid_desc_bk0_n_bk1_
,
arg
.
c_grid_desc_mblock_mperblock_nblock_nperblock_
,
arg
.
compute_ptr_offset_of_batch_
,
arg
.
block_2_ctile_map_
);
return
0.0
f
;
}
else
{
return
launch_and_time_kernel
(
kernel
,
nrepeat
,
dim3
(
grid_size
),
dim3
(
BlockSize
),
0
,
arg
.
p_a_grid_
,
arg
.
p_b_grid_
,
arg
.
p_c_grid_
,
arg
.
BatchCount_
,
arg
.
a_element_op_
,
arg
.
b_element_op_
,
arg
.
c_element_op_
,
arg
.
a_grid_desc_ak0_m_ak1_
,
arg
.
b_grid_desc_bk0_n_bk1_
,
arg
.
c_grid_desc_mblock_mperblock_nblock_nperblock_
,
arg
.
compute_ptr_offset_of_batch_
,
arg
.
block_2_ctile_map_
);
}
};
if
(
has_main_k0_block_loop
)
{
const
auto
kernel
=
ck
::
kernel_batched_gemm_xdl_cshuffle_v1
<
GridwiseGemm
,
ADataType
,
// TODO: distiguish A/B datatype
CDataType
,
AElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
,
DeviceOp
::
AGridDesc_AK0_M_AK1
,
DeviceOp
::
BGridDesc_BK0_N_BK1
,
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
ComputePtrOffsetOfStridedBatch
,
Block2CTileMap
,
true
>
;
ave_time
=
Run
(
kernel
);
}
else
{
const
auto
kernel
=
ck
::
kernel_batched_gemm_xdl_cshuffle_v1
<
GridwiseGemm
,
ADataType
,
// TODO: distiguish A/B datatype
CDataType
,
AElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
,
DeviceOp
::
AGridDesc_AK0_M_AK1
,
DeviceOp
::
BGridDesc_BK0_N_BK1
,
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
ComputePtrOffsetOfStridedBatch
,
Block2CTileMap
,
false
>
;
ave_time
=
Run
(
kernel
);
}
}
return
ave_time
;
return
ave_time
;
}
}
...
@@ -884,7 +743,10 @@ struct DeviceGemmXdlSplitKCShuffle
...
@@ -884,7 +743,10 @@ struct DeviceGemmXdlSplitKCShuffle
static
bool
IsSupportedArgument
(
const
Argument
&
arg
)
static
bool
IsSupportedArgument
(
const
Argument
&
arg
)
{
{
return
GridwiseGemm
::
CheckValidity
(
return
GridwiseGemm
::
CheckValidity
(
arg
.
a_grid_desc_ak0_m_ak1_
,
arg
.
b_grid_desc_bk0_n_bk1_
,
arg
.
c_grid_desc_m_n_
);
arg
.
a_grid_desc_ak0_m_ak1_
,
arg
.
b_grid_desc_bk0_n_bk1_
,
arg
.
c_grid_desc_m_n_
)
&&
GridwiseGemm
::
CheckValidity
(
arg
.
a_grid_desc_ak0_m_ak1_tail_
,
arg
.
b_grid_desc_bk0_n_bk1_tail_
,
arg
.
c_grid_desc_m_n_
);
}
}
// polymorphic
// polymorphic
...
...
library/include/ck/library/reference_tensor_operation/cpu/reference_gemm.hpp
View file @
4511f877
...
@@ -62,8 +62,8 @@ struct ReferenceGemm : public device::BaseOperator
...
@@ -62,8 +62,8 @@ struct ReferenceGemm : public device::BaseOperator
float
v_a
;
float
v_a
;
float
v_b
;
float
v_b
;
arg
.
a_element_op_
(
v_a
,
static_cas
t
<
const
float
>
(
arg
.
a_m_k_
(
m
,
k
)));
arg
.
a_element_op_
(
v_a
,
ck
::
type_conver
t
<
const
float
>
(
arg
.
a_m_k_
(
m
,
k
)));
arg
.
b_element_op_
(
v_b
,
static_cas
t
<
const
float
>
(
arg
.
b_k_n_
(
k
,
n
)));
arg
.
b_element_op_
(
v_b
,
ck
::
type_conver
t
<
const
float
>
(
arg
.
b_k_n_
(
k
,
n
)));
v_acc
+=
v_a
*
v_b
;
v_acc
+=
v_a
*
v_b
;
}
}
...
@@ -72,7 +72,7 @@ struct ReferenceGemm : public device::BaseOperator
...
@@ -72,7 +72,7 @@ struct ReferenceGemm : public device::BaseOperator
arg
.
c_element_op_
(
v_c
,
v_acc
);
arg
.
c_element_op_
(
v_c
,
v_acc
);
arg
.
c_m_n_
(
m
,
n
)
=
v_c
;
arg
.
c_m_n_
(
m
,
n
)
=
ck
::
type_convert
<
CDataType
>
(
v_c
)
;
};
};
make_ParallelTensorFunctor
(
make_ParallelTensorFunctor
(
...
...
library/src/tensor_operation_instance/gpu/CMakeLists.txt
View file @
4511f877
...
@@ -24,6 +24,7 @@ function(add_instance_library INSTANCE_NAME)
...
@@ -24,6 +24,7 @@ function(add_instance_library INSTANCE_NAME)
endfunction
(
add_instance_library INSTANCE_NAME
)
endfunction
(
add_instance_library INSTANCE_NAME
)
add_subdirectory
(
gemm
)
add_subdirectory
(
gemm
)
add_subdirectory
(
gemm_splitk
)
add_subdirectory
(
gemm_bias2d
)
add_subdirectory
(
gemm_bias2d
)
add_subdirectory
(
gemm_bias_relu
)
add_subdirectory
(
gemm_bias_relu
)
add_subdirectory
(
gemm_bias_relu_add
)
add_subdirectory
(
gemm_bias_relu_add
)
...
@@ -34,7 +35,6 @@ add_subdirectory(conv2d_fwd)
...
@@ -34,7 +35,6 @@ add_subdirectory(conv2d_fwd)
add_subdirectory
(
conv3d_fwd
)
add_subdirectory
(
conv3d_fwd
)
add_subdirectory
(
conv2d_fwd_bias_relu
)
add_subdirectory
(
conv2d_fwd_bias_relu
)
add_subdirectory
(
conv2d_fwd_bias_relu_add
)
add_subdirectory
(
conv2d_fwd_bias_relu_add
)
add_subdirectory
(
conv2d_fwd_bias_relu_atomic_add
)
add_subdirectory
(
conv2d_bwd_data
)
add_subdirectory
(
conv2d_bwd_data
)
add_subdirectory
(
reduce
)
add_subdirectory
(
reduce
)
add_subdirectory
(
convnd_bwd_data
)
add_subdirectory
(
convnd_bwd_data
)
...
...
library/src/tensor_operation_instance/gpu/batched_gemm/CMakeLists.txt
View file @
4511f877
...
@@ -12,10 +12,10 @@ set(DEVICE_BATCHED_GEMM_INSTANCE_SOURCE
...
@@ -12,10 +12,10 @@ set(DEVICE_BATCHED_GEMM_INSTANCE_SOURCE
device_batched_gemm_xdl_f32_f32_f32_gmk_gnk_gmn_instance.cpp;
device_batched_gemm_xdl_f32_f32_f32_gmk_gnk_gmn_instance.cpp;
device_batched_gemm_xdl_f32_f32_f32_gkm_gkn_gmn_instance.cpp;
device_batched_gemm_xdl_f32_f32_f32_gkm_gkn_gmn_instance.cpp;
device_batched_gemm_xdl_f32_f32_f32_gkm_gnk_gmn_instance.cpp;
device_batched_gemm_xdl_f32_f32_f32_gkm_gnk_gmn_instance.cpp;
device_batched_gemm_xdl_i
nt
8_i
nt
8_i
nt
8_gmk_gkn_gmn_instance.cpp;
device_batched_gemm_xdl_i8_i8_i8_gmk_gkn_gmn_instance.cpp;
device_batched_gemm_xdl_i
nt
8_i
nt
8_i
nt
8_gmk_gnk_gmn_instance.cpp;
device_batched_gemm_xdl_i8_i8_i8_gmk_gnk_gmn_instance.cpp;
device_batched_gemm_xdl_i
nt
8_i
nt
8_i
nt
8_gkm_gkn_gmn_instance.cpp;
device_batched_gemm_xdl_i8_i8_i8_gkm_gkn_gmn_instance.cpp;
device_batched_gemm_xdl_i
nt
8_i
nt
8_i
nt
8_gkm_gnk_gmn_instance.cpp;
device_batched_gemm_xdl_i8_i8_i8_gkm_gnk_gmn_instance.cpp;
)
)
add_library
(
device_batched_gemm_instance SHARED
${
DEVICE_BATCHED_GEMM_INSTANCE_SOURCE
}
)
add_library
(
device_batched_gemm_instance SHARED
${
DEVICE_BATCHED_GEMM_INSTANCE_SOURCE
}
)
...
...
library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_i
nt
8_i
nt
8_i
nt
8_gkm_gkn_gmn_instance.cpp
→
library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_i8_i8_i8_gkm_gkn_gmn_instance.cpp
View file @
4511f877
...
@@ -23,7 +23,7 @@ using AccData = int32_t;
...
@@ -23,7 +23,7 @@ using AccData = int32_t;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
// Compilation parameters for a[m, k] * b[n, k] = c[m, n]
// Compilation parameters for a[m, k] * b[n, k] = c[m, n]
using
device_batched_gemm_xdl_i
nt
8_i
nt
8_i
nt
8_gkm_gkn_gmn_instances
=
std
::
tuple
<
using
device_batched_gemm_xdl_i8_i8_i8_gkm_gkn_gmn_instances
=
std
::
tuple
<
// clang-format off
// clang-format off
//##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer|
//##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer|
//##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
//##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
...
@@ -53,11 +53,11 @@ using device_batched_gemm_xdl_int8_int8_int8_gkm_gkn_gmn_instances = std::tuple<
...
@@ -53,11 +53,11 @@ using device_batched_gemm_xdl_int8_int8_int8_gkm_gkn_gmn_instances = std::tuple<
// clang-format on
// clang-format on
>
;
>
;
void
add_device_batched_gemm_xdl_i
nt
8_i
nt
8_i
nt
8_gkm_gkn_gmn_instances
(
void
add_device_batched_gemm_xdl_i8_i8_i8_gkm_gkn_gmn_instances
(
std
::
vector
<
DeviceGemmPtr
<
PassThrough
,
PassThrough
,
PassThrough
>>&
instances
)
std
::
vector
<
DeviceGemmPtr
<
PassThrough
,
PassThrough
,
PassThrough
>>&
instances
)
{
{
add_device_operation_instances
(
instances
,
add_device_operation_instances
(
instances
,
device_batched_gemm_xdl_i
nt
8_i
nt
8_i
nt
8_gkm_gkn_gmn_instances
{});
device_batched_gemm_xdl_i8_i8_i8_gkm_gkn_gmn_instances
{});
}
}
}
// namespace device_batched_gemm_instance
}
// namespace device_batched_gemm_instance
...
...
library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_i8_i8_i8_gkm_gnk_gmn_instance.cpp
0 → 100644
View file @
4511f877
#include <stdlib.h>
#include "config.hpp"
#include "device_batched_gemm_xdl.hpp"
#include "element_wise_operation.hpp"
#include "device_operation_instance.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
device_batched_gemm_instance
{
// Short names for the GEMM layout tags used in the instance list below.
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
// Shorthand for a compile-time ck::Sequence of index values, e.g. S<4, 64, 1>.
template <ck::index_t... Values>
using S = ck::Sequence<Values...>;
// Element types for this instance file: 8-bit signed integers for the A, B and C
// operands, with a 32-bit signed integer accumulator.
using AData   = int8_t;
using BData   = int8_t;
using CData   = int8_t;
using AccData = int32_t;
// Element-wise operation tag passed for the A, B and C operation slots of every
// instance in this file.
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
// Compilation parameters for a[k, m] * b[n, k] = c[m, n]
using
device_batched_gemm_xdl_i8_i8_i8_gkm_gnk_gmn_instances
=
std
::
tuple
<
// clang-format off
//##################| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer|
//##################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
//##################| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
//##################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceBatchedGemmXdl
<
AData
,
BData
,
CData
,
AccData
,
Col
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
256
,
256
,
128
,
4
,
16
,
32
,
32
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
16
,
16
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
true
,
7
,
1
>
,
DeviceBatchedGemmXdl
<
AData
,
BData
,
CData
,
AccData
,
Col
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
256
,
128
,
256
,
4
,
16
,
32
,
32
,
2
,
4
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
16
,
16
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
true
,
7
,
1
>
,
DeviceBatchedGemmXdl
<
AData
,
BData
,
CData
,
AccData
,
Col
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
128
,
128
,
128
,
4
,
16
,
32
,
32
,
4
,
2
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
16
,
16
,
true
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
true
,
7
,
1
>
,
DeviceBatchedGemmXdl
<
AData
,
BData
,
CData
,
AccData
,
Col
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
256
,
128
,
128
,
4
,
16
,
32
,
32
,
2
,
2
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
16
,
16
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
true
,
7
,
1
>
,
DeviceBatchedGemmXdl
<
AData
,
BData
,
CData
,
AccData
,
Col
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
128
,
128
,
64
,
4
,
16
,
32
,
32
,
2
,
2
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
16
,
16
,
true
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
true
,
7
,
1
>
,
DeviceBatchedGemmXdl
<
AData
,
BData
,
CData
,
AccData
,
Col
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
128
,
64
,
128
,
4
,
16
,
32
,
32
,
2
,
2
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
16
,
16
,
true
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
true
,
7
,
1
>
,
DeviceBatchedGemmXdl
<
AData
,
BData
,
CData
,
AccData
,
Col
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
64
,
64
,
64
,
4
,
16
,
32
,
32
,
2
,
2
,
S
<
4
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
16
,
16
,
true
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
true
,
7
,
1
>
,
DeviceBatchedGemmXdl
<
AData
,
BData
,
CData
,
AccData
,
Col
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
256
,
128
,
64
,
4
,
16
,
32
,
32
,
2
,
1
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
16
,
16
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
true
,
7
,
1
>
,
DeviceBatchedGemmXdl
<
AData
,
BData
,
CData
,
AccData
,
Col
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
256
,
64
,
128
,
4
,
16
,
32
,
32
,
1
,
2
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
16
,
16
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
true
,
7
,
1
>
,
DeviceBatchedGemmXdl
<
AData
,
BData
,
CData
,
AccData
,
Col
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
128
,
128
,
32
,
4
,
16
,
32
,
32
,
2
,
1
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
16
,
16
,
true
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
true
,
7
,
1
>
,
DeviceBatchedGemmXdl
<
AData
,
BData
,
CData
,
AccData
,
Col
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
128
,
32
,
128
,
4
,
16
,
32
,
32
,
1
,
2
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
16
,
16
,
true
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
true
,
7
,
1
>
,
DeviceBatchedGemmXdl
<
AData
,
BData
,
CData
,
AccData
,
Col
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
64
,
64
,
32
,
4
,
16
,
32
,
32
,
2
,
1
,
S
<
4
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
16
,
16
,
true
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
true
,
7
,
1
>
,
DeviceBatchedGemmXdl
<
AData
,
BData
,
CData
,
AccData
,
Col
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
64
,
32
,
64
,
4
,
16
,
32
,
32
,
1
,
2
,
S
<
4
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
16
,
16
,
true
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
true
,
7
,
1
>
,
DeviceBatchedGemmXdl
<
AData
,
BData
,
CData
,
AccData
,
Col
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
256
,
256
,
128
,
4
,
16
,
32
,
32
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
8
,
16
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
true
,
7
,
1
>
,
DeviceBatchedGemmXdl
<
AData
,
BData
,
CData
,
AccData
,
Col
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
256
,
128
,
256
,
4
,
16
,
32
,
32
,
2
,
4
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
16
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
true
,
7
,
1
>
,
DeviceBatchedGemmXdl
<
AData
,
BData
,
CData
,
AccData
,
Col
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
128
,
128
,
128
,
4
,
16
,
32
,
32
,
4
,
2
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
8
,
16
,
true
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
true
,
7
,
1
>
,
DeviceBatchedGemmXdl
<
AData
,
BData
,
CData
,
AccData
,
Col
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
256
,
128
,
128
,
4
,
16
,
32
,
32
,
2
,
2
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
16
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
true
,
7
,
1
>
,
DeviceBatchedGemmXdl
<
AData
,
BData
,
CData
,
AccData
,
Col
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
128
,
128
,
64
,
4
,
16
,
32
,
32
,
2
,
2
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
8
,
16
,
true
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
true
,
7
,
1
>
,
DeviceBatchedGemmXdl
<
AData
,
BData
,
CData
,
AccData
,
Col
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
128
,
64
,
128
,
4
,
16
,
32
,
32
,
2
,
2
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
16
,
true
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
true
,
7
,
1
>
,
DeviceBatchedGemmXdl
<
AData
,
BData
,
CData
,
AccData
,
Col
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
256
,
128
,
64
,
4
,
16
,
32
,
32
,
2
,
1
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
16
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
true
,
7
,
1
>
,
DeviceBatchedGemmXdl
<
AData
,
BData
,
CData
,
AccData
,
Col
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
256
,
64
,
128
,
4
,
16
,
32
,
32
,
1
,
2
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
16
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
true
,
7
,
1
>
// clang-format on
>
;
// Registers the int8 batched-GEMM XDL instances of this file (A: Col, B: Col, C: Row).
//
// Forwards `instances` together with a default-constructed
// device_batched_gemm_xdl_i8_i8_i8_gkm_gnk_gmn_instances tuple to
// add_device_operation_instances; presumably that helper appends one device
// operation per tuple element to `instances` (its definition is not in this file).
void add_device_batched_gemm_xdl_i8_i8_i8_gkm_gnk_gmn_instances(
    std::vector<DeviceGemmPtr<PassThrough, PassThrough, PassThrough>>& instances)
{
    add_device_operation_instances(
        instances, device_batched_gemm_xdl_i8_i8_i8_gkm_gnk_gmn_instances{});
}
}
// namespace device_batched_gemm_instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_i
nt
8_i
nt
8_i
nt
8_gmk_gkn_gmn_instance.cpp
→
library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_i8_i8_i8_gmk_gkn_gmn_instance.cpp
View file @
4511f877
...
@@ -23,7 +23,7 @@ using AccData = int32_t;
...
@@ -23,7 +23,7 @@ using AccData = int32_t;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
// Compilation parameters for a[m, k] * b[k, n] = c[m, n]
// Compilation parameters for a[m, k] * b[k, n] = c[m, n]
using
device_batched_gemm_xdl_i
nt
8_i
nt
8_i
nt
8_gmk_gkn_gmn_instances
=
std
::
tuple
<
using
device_batched_gemm_xdl_i8_i8_i8_gmk_gkn_gmn_instances
=
std
::
tuple
<
// clang-format off
// clang-format off
//##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer|
//##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer|
//##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
//##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
...
@@ -53,11 +53,11 @@ using device_batched_gemm_xdl_int8_int8_int8_gmk_gkn_gmn_instances = std::tuple<
...
@@ -53,11 +53,11 @@ using device_batched_gemm_xdl_int8_int8_int8_gmk_gkn_gmn_instances = std::tuple<
// clang-format on
// clang-format on
>
;
>
;
void
add_device_batched_gemm_xdl_i
nt
8_i
nt
8_i
nt
8_gmk_gkn_gmn_instances
(
void
add_device_batched_gemm_xdl_i8_i8_i8_gmk_gkn_gmn_instances
(
std
::
vector
<
DeviceGemmPtr
<
PassThrough
,
PassThrough
,
PassThrough
>>&
instances
)
std
::
vector
<
DeviceGemmPtr
<
PassThrough
,
PassThrough
,
PassThrough
>>&
instances
)
{
{
add_device_operation_instances
(
instances
,
add_device_operation_instances
(
instances
,
device_batched_gemm_xdl_i
nt
8_i
nt
8_i
nt
8_gmk_gkn_gmn_instances
{});
device_batched_gemm_xdl_i8_i8_i8_gmk_gkn_gmn_instances
{});
}
}
}
// namespace device_batched_gemm_instance
}
// namespace device_batched_gemm_instance
...
...
library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_i
nt
8_i
nt
8_i
nt
8_gmk_gnk_gmn_instance.cpp
→
library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_i8_i8_i8_gmk_gnk_gmn_instance.cpp
View file @
4511f877
...
@@ -23,7 +23,7 @@ using AccData = int32_t;
...
@@ -23,7 +23,7 @@ using AccData = int32_t;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
// Compilation parameters for a[m, k] * b[n, k] = c[m, n]
// Compilation parameters for a[m, k] * b[n, k] = c[m, n]
using
device_batched_gemm_xdl_i
nt
8_i
nt
8_i
nt
8_gmk_gnk_gmn_instances
=
std
::
tuple
<
using
device_batched_gemm_xdl_i8_i8_i8_gmk_gnk_gmn_instances
=
std
::
tuple
<
// clang-format off
// clang-format off
//##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer|
//##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer|
//##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
//##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
...
@@ -45,11 +45,11 @@ using device_batched_gemm_xdl_int8_int8_int8_gmk_gnk_gmn_instances = std::tuple<
...
@@ -45,11 +45,11 @@ using device_batched_gemm_xdl_int8_int8_int8_gmk_gnk_gmn_instances = std::tuple<
// clang-format on
// clang-format on
>
;
>
;
void
add_device_batched_gemm_xdl_i
nt
8_i
nt
8_i
nt
8_gmk_gnk_gmn_instances
(
void
add_device_batched_gemm_xdl_i8_i8_i8_gmk_gnk_gmn_instances
(
std
::
vector
<
DeviceGemmPtr
<
PassThrough
,
PassThrough
,
PassThrough
>>&
instances
)
std
::
vector
<
DeviceGemmPtr
<
PassThrough
,
PassThrough
,
PassThrough
>>&
instances
)
{
{
add_device_operation_instances
(
instances
,
add_device_operation_instances
(
instances
,
device_batched_gemm_xdl_i
nt
8_i
nt
8_i
nt
8_gmk_gnk_gmn_instances
{});
device_batched_gemm_xdl_i8_i8_i8_gmk_gnk_gmn_instances
{});
}
}
}
// namespace device_batched_gemm_instance
}
// namespace device_batched_gemm_instance
...
...
library/src/tensor_operation_instance/gpu/batched_gemm/device_batched_gemm_xdl_int8_int8_int8_gkm_gnk_gmn_instance.cpp
deleted
100644 → 0
View file @
519b6aaf
#include <stdlib.h>
#include "config.hpp"
#include "device_batched_gemm_xdl.hpp"
#include "element_wise_operation.hpp"
#include "device_operation_instance.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
device_batched_gemm_instance
{
using
Row
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
Col
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
template
<
ck
::
index_t
...
Is
>
using
S
=
ck
::
Sequence
<
Is
...
>
;
using
AData
=
int8_t
;
using
BData
=
int8_t
;
using
CData
=
int8_t
;
using
AccData
=
int32_t
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
// Compilation parameters for a[k, m] * b[n, k] = c[m, n]
using
device_batched_gemm_xdl_int8_int8_int8_gkm_gnk_gmn_instances
=
std
::
tuple
<
// clang-format off
//##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer|
//##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
//##########| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
//##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceBatchedGemmXdl
<
AData
,
BData
,
CData
,
AccData
,
Col
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
256
,
256
,
128
,
4
,
16
,
32
,
32
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
16
,
16
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
true
,
7
,
1
>
,
DeviceBatchedGemmXdl
<
AData
,
BData
,
CData
,
AccData
,
Col
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
256
,
128
,
256
,
4
,
16
,
32
,
32
,
2
,
4
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
16
,
16
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
true
,
7
,
1
>
,
DeviceBatchedGemmXdl
<
AData
,
BData
,
CData
,
AccData
,
Col
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
128
,
128
,
128
,
4
,
16
,
32
,
32
,
4
,
2
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
16
,
16
,
true
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
true
,
7
,
1
>
,
DeviceBatchedGemmXdl
<
AData
,
BData
,
CData
,
AccData
,
Col
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
256
,
128
,
128
,
4
,
16
,
32
,
32
,
2
,
2
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
16
,
16
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
true
,
7
,
1
>
,
DeviceBatchedGemmXdl
<
AData
,
BData
,
CData
,
AccData
,
Col
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
128
,
128
,
64
,
4
,
16
,
32
,
32
,
2
,
2
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
16
,
16
,
true
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
true
,
7
,
1
>
,
DeviceBatchedGemmXdl
<
AData
,
BData
,
CData
,
AccData
,
Col
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
128
,
64
,
128
,
4
,
16
,
32
,
32
,
2
,
2
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
16
,
16
,
true
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
true
,
7
,
1
>
,
DeviceBatchedGemmXdl
<
AData
,
BData
,
CData
,
AccData
,
Col
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
64
,
64
,
64
,
4
,
16
,
32
,
32
,
2
,
2
,
S
<
4
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
16
,
16
,
true
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
true
,
7
,
1
>
,
DeviceBatchedGemmXdl
<
AData
,
BData
,
CData
,
AccData
,
Col
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
256
,
128
,
64
,
4
,
16
,
32
,
32
,
2
,
1
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
16
,
16
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
true
,
7
,
1
>
,
DeviceBatchedGemmXdl
<
AData
,
BData
,
CData
,
AccData
,
Col
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
256
,
64
,
128
,
4
,
16
,
32
,
32
,
1
,
2
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
16
,
16
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
true
,
7
,
1
>
,
DeviceBatchedGemmXdl
<
AData
,
BData
,
CData
,
AccData
,
Col
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
128
,
128
,
32
,
4
,
16
,
32
,
32
,
2
,
1
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
16
,
16
,
true
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
true
,
7
,
1
>
,
DeviceBatchedGemmXdl
<
AData
,
BData
,
CData
,
AccData
,
Col
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
128
,
32
,
128
,
4
,
16
,
32
,
32
,
1
,
2
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
16
,
16
,
true
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
true
,
7
,
1
>
,
DeviceBatchedGemmXdl
<
AData
,
BData
,
CData
,
AccData
,
Col
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
64
,
64
,
32
,
4
,
16
,
32
,
32
,
2
,
1
,
S
<
4
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
16
,
16
,
true
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
true
,
7
,
1
>
,
DeviceBatchedGemmXdl
<
AData
,
BData
,
CData
,
AccData
,
Col
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
64
,
32
,
64
,
4
,
16
,
32
,
32
,
1
,
2
,
S
<
4
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
16
,
16
,
true
,
S
<
4
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
true
,
7
,
1
>
,
DeviceBatchedGemmXdl
<
AData
,
BData
,
CData
,
AccData
,
Col
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
256
,
256
,
128
,
4
,
16
,
32
,
32
,
4
,
2
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
8
,
16
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
true
,
7
,
1
>
,
DeviceBatchedGemmXdl
<
AData
,
BData
,
CData
,
AccData
,
Col
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
256
,
128
,
256
,
4
,
16
,
32
,
32
,
2
,
4
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
16
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
true
,
7
,
1
>
,
DeviceBatchedGemmXdl
<
AData
,
BData
,
CData
,
AccData
,
Col
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
128
,
128
,
128
,
4
,
16
,
32
,
32
,
4
,
2
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
8
,
16
,
true
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
true
,
7
,
1
>
,
DeviceBatchedGemmXdl
<
AData
,
BData
,
CData
,
AccData
,
Col
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
256
,
128
,
128
,
4
,
16
,
32
,
32
,
2
,
2
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
16
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
true
,
7
,
1
>
,
DeviceBatchedGemmXdl
<
AData
,
BData
,
CData
,
AccData
,
Col
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
128
,
128
,
64
,
4
,
16
,
32
,
32
,
2
,
2
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
8
,
16
,
true
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
true
,
7
,
1
>
,
DeviceBatchedGemmXdl
<
AData
,
BData
,
CData
,
AccData
,
Col
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
128
,
64
,
128
,
4
,
16
,
32
,
32
,
2
,
2
,
S
<
4
,
32
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
16
,
true
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
true
,
7
,
1
>
,
DeviceBatchedGemmXdl
<
AData
,
BData
,
CData
,
AccData
,
Col
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
256
,
128
,
64
,
4
,
16
,
32
,
32
,
2
,
1
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
16
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
true
,
7
,
1
>
,
DeviceBatchedGemmXdl
<
AData
,
BData
,
CData
,
AccData
,
Col
,
Col
,
Row
,
PassThrough
,
PassThrough
,
PassThrough
,
256
,
64
,
128
,
4
,
16
,
32
,
32
,
1
,
2
,
S
<
4
,
64
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
2
,
16
,
true
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
true
,
7
,
1
>
// clang-format on
>
;
void
add_device_batched_gemm_xdl_int8_int8_int8_gkm_gnk_gmn_instances
(
std
::
vector
<
DeviceGemmPtr
<
PassThrough
,
PassThrough
,
PassThrough
>>&
instances
)
{
add_device_operation_instances
(
instances
,
device_batched_gemm_xdl_int8_int8_int8_gkm_gnk_gmn_instances
{});
}
}
// namespace device_batched_gemm_instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
library/src/tensor_operation_instance/gpu/conv2d_fwd_bias_relu_atomic_add/CMakeLists.txt
deleted
100644 → 0
View file @
519b6aaf
# device_conv2d_fwd_bias_relu_atomic_add_instance
set
(
DEVICE_CONV2D_FWD_BIAS_RELU_ATOMIC_ADD_INSTANCE_SOURCE
device_conv2d_fwd_xdl_c_shuffle_bias_relu_atomic_add_nhwc_kyxc_nhwk_f16_instance.cpp;
)
add_library
(
device_conv2d_fwd_bias_relu_atomic_add_instance SHARED
${
DEVICE_CONV2D_FWD_BIAS_RELU_ATOMIC_ADD_INSTANCE_SOURCE
}
)
target_compile_features
(
device_conv2d_fwd_bias_relu_atomic_add_instance PUBLIC
)
set_target_properties
(
device_conv2d_fwd_bias_relu_atomic_add_instance PROPERTIES POSITION_INDEPENDENT_CODE ON
)
install
(
TARGETS device_conv2d_fwd_bias_relu_atomic_add_instance LIBRARY DESTINATION lib
)
clang_tidy_check
(
device_conv2d_fwd_bias_relu_atomic_add_instance
)
library/src/tensor_operation_instance/gpu/conv2d_fwd_bias_relu_atomic_add/device_conv2d_fwd_xdl_c_shuffle_bias_relu_atomic_add_nhwc_kyxc_nhwk_f16_instance.cpp
deleted
100644 → 0
View file @
519b6aaf
#include <stdlib.h>
#include "config.hpp"
#include "device_conv2d_fwd_xdl_c_shuffle_bias_activation_nhwc_kyxc_nhwk.hpp"
#include "element_wise_operation.hpp"
#include "device_operation_instance.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
device_conv2d_fwd_bias_activation_atomic_add_instance
{
using
F16
=
ck
::
half_t
;
using
F32
=
float
;
template
<
ck
::
index_t
...
Is
>
using
S
=
ck
::
Sequence
<
Is
...
>
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
AddRelu
=
ck
::
tensor_operation
::
element_wise
::
AddRelu
;
static
constexpr
auto
InMemoryAtomicAdd
=
ck
::
InMemoryDataOperationEnum
::
AtomicAdd
;
static
constexpr
auto
ConvFwdDefault
=
ck
::
tensor_operation
::
device
::
ConvolutionForwardSpecialization
::
Default
;
// Tuple of device-op instantiations, one row per tuning configuration.
//
// Template-argument order of every row (taken from the original column header):
//   InDataType, WeiDataType, OutDataType, AccDataType,
//   InElementwiseOperation, WeiElementwiseOperation, OutElementwiseOperation,
//   OutGlobalMemoryDataOperation, ConvForwardSpecialization,
//   BlockSize, MPerBlock, NPerBlock, K0PerBlock, K1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave,
//   ABlockTransfer: ThreadClusterLengths_K0_M_K1, ThreadClusterArrangeOrder, SrcAccessOrder,
//                   SrcVectorDim, SrcScalarPerVector, DstScalarPerVector_K1, ABlockLdsAddExtraM,
//   BBlockTransfer: ThreadClusterLengths_K0_N_K1, ThreadClusterArrangeOrder, SrcAccessOrder,
//                   SrcVectorDim, SrcScalarPerVector, DstScalarPerVector_K1, BBlockLdsAddExtraN,
//   CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle,
//   CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl,
//   CBlockTransferScalarPerVector_NWaveNPerXdl
using device_conv2d_fwd_xdl_c_shuffle_bias_relu_atomic_add_nhwc_kyxc_nhwk_f16_instances = std::tuple<
    // clang-format off
    DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, InMemoryAtomicAdd, ConvFwdDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 8, 1, 1, 32>, 2>,
    DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, InMemoryAtomicAdd, ConvFwdDefault, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 8, 1, 1, 32>, 2>,
    DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, InMemoryAtomicAdd, ConvFwdDefault, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 4, 1, 1, 32>, 2>,
    DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, InMemoryAtomicAdd, ConvFwdDefault, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 8, 1, 1, 32>, 2>,
    DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, InMemoryAtomicAdd, ConvFwdDefault, 128, 128,  64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 8, 1, 1, 16>, 2>,
    DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, InMemoryAtomicAdd, ConvFwdDefault, 128,  64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 4, 1, 1, 32>, 2>,
    DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, InMemoryAtomicAdd, ConvFwdDefault,  64,  64,  64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 4, 1, 1, 16>, 2>,
    DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, InMemoryAtomicAdd, ConvFwdDefault, 256, 128,  64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 8, 1, 1, 32>, 2>,
    DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, InMemoryAtomicAdd, ConvFwdDefault, 256,  64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 8, 1, 1, 32>, 2>,
    DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, InMemoryAtomicAdd, ConvFwdDefault, 128, 128,  32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 8, 1, 1, 16>, 2>,
    DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, InMemoryAtomicAdd, ConvFwdDefault, 128,  32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 4, 1, 1, 32>, 2>,
    DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, InMemoryAtomicAdd, ConvFwdDefault,  64,  64,  32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 4, 1, 1, 16>, 2>,
    DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<F16, F16, F16, F32, PassThrough, PassThrough, AddRelu, InMemoryAtomicAdd, ConvFwdDefault,  64,  32,  64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 1, 4, 1, 1, 16>, 2>
    // clang-format on
    >;
// Appends one heap-allocated, default-constructed object per entry of the
// instance tuple to `instance_container`, in tuple order.
void add_device_conv2d_fwd_xdl_c_shuffle_bias_relu_atomic_add_nhwc_kyxc_nhwk_f16_instances(
    std::vector<DeviceConvFwdBiasActivationPtr<PassThrough, PassThrough, AddRelu>>&
        instance_container)
{
    using InstanceTuple =
        device_conv2d_fwd_xdl_c_shuffle_bias_relu_atomic_add_nhwc_kyxc_nhwk_f16_instances;

    // Only used inside decltype below to name the type of each tuple element.
    const InstanceTuple all_ops{};

    // Compile-time loop over [0, tuple size) with step 1; `idx` is usable as a
    // template argument for std::get.
    ck::static_for<0, std::tuple_size_v<InstanceTuple>, 1>{}([&](auto idx) {
        using Op = remove_cvref_t<decltype(std::get<idx>(all_ops))>;

        Op op{};

        // The container entry is a copy of the freshly constructed op.
        instance_container.push_back(std::make_unique<Op>(op));
    });
}
}
// namespace device_conv2d_fwd_bias_activation_atomic_add_instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
library/src/tensor_operation_instance/gpu/gemm/CMakeLists.txt
View file @
4511f877
...
@@ -8,10 +8,10 @@ set(DEVICE_GEMM_INSTANCE_SOURCE
...
@@ -8,10 +8,10 @@ set(DEVICE_GEMM_INSTANCE_SOURCE
device_gemm_xdl_f16_f16_f16_mk_nk_mn_instance.cpp;
device_gemm_xdl_f16_f16_f16_mk_nk_mn_instance.cpp;
device_gemm_xdl_f16_f16_f16_km_kn_mn_instance.cpp;
device_gemm_xdl_f16_f16_f16_km_kn_mn_instance.cpp;
device_gemm_xdl_f16_f16_f16_km_nk_mn_instance.cpp;
device_gemm_xdl_f16_f16_f16_km_nk_mn_instance.cpp;
device_gemm_xdl_c_shuffle_i
nt
8_i
nt
8_i
nt
8_mk_kn_mn_instance.cpp;
device_gemm_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instance.cpp;
device_gemm_xdl_c_shuffle_i
nt
8_i
nt
8_i
nt
8_mk_nk_mn_instance.cpp;
device_gemm_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instance.cpp;
device_gemm_xdl_c_shuffle_i
nt
8_i
nt
8_i
nt
8_km_kn_mn_instance.cpp;
device_gemm_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instance.cpp;
device_gemm_xdl_c_shuffle_i
nt
8_i
nt
8_i
nt
8_km_nk_mn_instance.cpp;
device_gemm_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instance.cpp;
device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_kn_mn_instance.cpp;
device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_kn_mn_instance.cpp;
device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_nk_mn_instance.cpp;
device_gemm_xdl_c_shuffle_bf16_bf16_bf16_mk_nk_mn_instance.cpp;
device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_kn_mn_instance.cpp;
device_gemm_xdl_c_shuffle_bf16_bf16_bf16_km_kn_mn_instance.cpp;
...
@@ -25,14 +25,6 @@ set(DEVICE_GEMM_INSTANCE_SOURCE
...
@@ -25,14 +25,6 @@ set(DEVICE_GEMM_INSTANCE_SOURCE
device_gemm_xdl_c_shuffle_f32_f32_f32_km_kn_mn_instance.cpp;
device_gemm_xdl_c_shuffle_f32_f32_f32_km_kn_mn_instance.cpp;
device_gemm_xdl_c_shuffle_f32_f32_f32_km_nk_mn_instance.cpp;
device_gemm_xdl_c_shuffle_f32_f32_f32_km_nk_mn_instance.cpp;
device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instance.cpp;
device_gemm_xdl_c_shuffle_2_stage_f16_f16_f16_mk_nk_mn_instance.cpp;
device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instance.cpp;
device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instance.cpp;
device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instance.cpp;
device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instance.cpp;
device_gemm_xdl_splitk_c_shuffle_f16_f16_f16_mk_kn_mn_instance.cpp;
device_gemm_xdl_splitk_c_shuffle_f16_f16_f16_mk_nk_mn_instance.cpp;
device_gemm_xdl_splitk_c_shuffle_f16_f16_f16_km_kn_mn_instance.cpp;
device_gemm_xdl_splitk_c_shuffle_f16_f16_f16_km_nk_mn_instance.cpp;
)
)
add_library
(
device_gemm_instance SHARED
${
DEVICE_GEMM_INSTANCE_SOURCE
}
)
add_library
(
device_gemm_instance SHARED
${
DEVICE_GEMM_INSTANCE_SOURCE
}
)
...
...
library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_i
nt
8_i
nt
8_i
nt
8_km_kn_mn_instance.cpp
→
library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instance.cpp
View file @
4511f877
...
@@ -22,7 +22,7 @@ using PassThrough = ck::tensor_operation::element_wise::PassThrough;
...
@@ -22,7 +22,7 @@ using PassThrough = ck::tensor_operation::element_wise::PassThrough;
static
constexpr
auto
GemmDefault
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
Default
;
static
constexpr
auto
GemmDefault
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
Default
;
// Compilation parameters for a[k, m] * b[k, n] = c[m, n]
// Compilation parameters for a[k, m] * b[k, n] = c[m, n]
using
device_gemm_xdl_c_shuffle_i
nt
8_i
nt
8_i
nt
8_km_kn_mn_instances
=
using
device_gemm_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instances
=
std
::
tuple
<
std
::
tuple
<
// clang-format off
// clang-format off
//#####################| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//#####################| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
...
@@ -48,11 +48,11 @@ using device_gemm_xdl_c_shuffle_int8_int8_int8_km_kn_mn_instances =
...
@@ -48,11 +48,11 @@ using device_gemm_xdl_c_shuffle_int8_int8_int8_km_kn_mn_instances =
// clang-format on
// clang-format on
>
;
>
;
void
add_device_gemm_xdl_c_shuffle_i
nt
8_i
nt
8_i
nt
8_km_kn_mn_instances
(
void
add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instances
(
std
::
vector
<
DeviceGemmPtr
<
PassThrough
,
PassThrough
,
PassThrough
>>&
instances
)
std
::
vector
<
DeviceGemmPtr
<
PassThrough
,
PassThrough
,
PassThrough
>>&
instances
)
{
{
add_device_operation_instances
(
instances
,
add_device_operation_instances
(
instances
,
device_gemm_xdl_c_shuffle_i
nt
8_i
nt
8_i
nt
8_km_kn_mn_instances
{});
device_gemm_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instances
{});
}
}
}
// namespace device_gemm_instance
}
// namespace device_gemm_instance
...
...
library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_i
nt
8_i
nt
8_i
nt
8_km_nk_mn_instance.cpp
→
library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instance.cpp
View file @
4511f877
...
@@ -22,7 +22,7 @@ using PassThrough = ck::tensor_operation::element_wise::PassThrough;
...
@@ -22,7 +22,7 @@ using PassThrough = ck::tensor_operation::element_wise::PassThrough;
static
constexpr
auto
GemmDefault
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
Default
;
static
constexpr
auto
GemmDefault
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
Default
;
// Compilation parameters for a[k, m] * b[n, k] = c[m, n]
// Compilation parameters for a[k, m] * b[n, k] = c[m, n]
using
device_gemm_xdl_c_shuffle_i
nt
8_i
nt
8_i
nt
8_km_nk_mn_instances
=
using
device_gemm_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instances
=
std
::
tuple
<
std
::
tuple
<
// clang-format off
// clang-format off
//#####################| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//#####################| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
...
@@ -48,11 +48,11 @@ using device_gemm_xdl_c_shuffle_int8_int8_int8_km_nk_mn_instances =
...
@@ -48,11 +48,11 @@ using device_gemm_xdl_c_shuffle_int8_int8_int8_km_nk_mn_instances =
// clang-format on
// clang-format on
>
;
>
;
void
add_device_gemm_xdl_c_shuffle_i
nt
8_i
nt
8_i
nt
8_km_nk_mn_instances
(
void
add_device_gemm_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instances
(
std
::
vector
<
DeviceGemmPtr
<
PassThrough
,
PassThrough
,
PassThrough
>>&
instances
)
std
::
vector
<
DeviceGemmPtr
<
PassThrough
,
PassThrough
,
PassThrough
>>&
instances
)
{
{
add_device_operation_instances
(
instances
,
add_device_operation_instances
(
instances
,
device_gemm_xdl_c_shuffle_i
nt
8_i
nt
8_i
nt
8_km_nk_mn_instances
{});
device_gemm_xdl_c_shuffle_i8_i8_i8_km_nk_mn_instances
{});
}
}
}
// namespace device_gemm_instance
}
// namespace device_gemm_instance
...
...
library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_i
nt
8_i
nt
8_i
nt
8_mk_kn_mn_instance.cpp
→
library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instance.cpp
View file @
4511f877
...
@@ -22,7 +22,7 @@ using PassThrough = ck::tensor_operation::element_wise::PassThrough;
...
@@ -22,7 +22,7 @@ using PassThrough = ck::tensor_operation::element_wise::PassThrough;
static
constexpr
auto
GemmDefault
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
Default
;
static
constexpr
auto
GemmDefault
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
Default
;
// Compilation parameters for a[m, k] * b[k, n] = c[m, n]
// Compilation parameters for a[m, k] * b[k, n] = c[m, n]
using
device_gemm_xdl_c_shuffle_i
nt
8_i
nt
8_i
nt
8_mk_kn_mn_instances
=
using
device_gemm_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instances
=
std
::
tuple
<
std
::
tuple
<
// clang-format off
// clang-format off
//#####################| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//#####################| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
...
@@ -48,11 +48,11 @@ using device_gemm_xdl_c_shuffle_int8_int8_int8_mk_kn_mn_instances =
...
@@ -48,11 +48,11 @@ using device_gemm_xdl_c_shuffle_int8_int8_int8_mk_kn_mn_instances =
// clang-format on
// clang-format on
>
;
>
;
void
add_device_gemm_xdl_c_shuffle_i
nt
8_i
nt
8_i
nt
8_mk_kn_mn_instances
(
void
add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instances
(
std
::
vector
<
DeviceGemmPtr
<
PassThrough
,
PassThrough
,
PassThrough
>>&
instances
)
std
::
vector
<
DeviceGemmPtr
<
PassThrough
,
PassThrough
,
PassThrough
>>&
instances
)
{
{
add_device_operation_instances
(
instances
,
add_device_operation_instances
(
instances
,
device_gemm_xdl_c_shuffle_i
nt
8_i
nt
8_i
nt
8_mk_kn_mn_instances
{});
device_gemm_xdl_c_shuffle_i8_i8_i8_mk_kn_mn_instances
{});
}
}
}
// namespace device_gemm_instance
}
// namespace device_gemm_instance
...
...
library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_i
nt
8_i
nt
8_i
nt
8_mk_nk_mn_instance.cpp
→
library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instance.cpp
View file @
4511f877
...
@@ -22,7 +22,7 @@ using PassThrough = ck::tensor_operation::element_wise::PassThrough;
...
@@ -22,7 +22,7 @@ using PassThrough = ck::tensor_operation::element_wise::PassThrough;
static
constexpr
auto
GemmDefault
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
Default
;
static
constexpr
auto
GemmDefault
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
Default
;
// Compilation parameters for a[m, k] * b[n, k] = c[m, n]
// Compilation parameters for a[m, k] * b[n, k] = c[m, n]
using
device_gemm_xdl_c_shuffle_i
nt
8_i
nt
8_i
nt
8_mk_nk_mn_instances
=
using
device_gemm_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instances
=
std
::
tuple
<
std
::
tuple
<
// clang-format off
// clang-format off
//#####################| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//#####################| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
...
@@ -45,11 +45,11 @@ using device_gemm_xdl_c_shuffle_int8_int8_int8_mk_nk_mn_instances =
...
@@ -45,11 +45,11 @@ using device_gemm_xdl_c_shuffle_int8_int8_int8_mk_nk_mn_instances =
// clang-format on
// clang-format on
>
;
>
;
void
add_device_gemm_xdl_c_shuffle_i
nt
8_i
nt
8_i
nt
8_mk_nk_mn_instances
(
void
add_device_gemm_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instances
(
std
::
vector
<
DeviceGemmPtr
<
PassThrough
,
PassThrough
,
PassThrough
>>&
instances
)
std
::
vector
<
DeviceGemmPtr
<
PassThrough
,
PassThrough
,
PassThrough
>>&
instances
)
{
{
add_device_operation_instances
(
instances
,
add_device_operation_instances
(
instances
,
device_gemm_xdl_c_shuffle_i
nt
8_i
nt
8_i
nt
8_mk_nk_mn_instances
{});
device_gemm_xdl_c_shuffle_i8_i8_i8_mk_nk_mn_instances
{});
}
}
}
// namespace device_gemm_instance
}
// namespace device_gemm_instance
...
...
library/src/tensor_operation_instance/gpu/gemm_splitk/CMakeLists.txt
0 → 100644
View file @
4511f877
# device_gemm_splitk_instance
# Source files that make up the device_gemm_splitk_instance library.
set
(
DEVICE_GEMM_SPLITK_INSTANCE_SOURCE
device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instance.cpp;
device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instance.cpp;
device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instance.cpp;
device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instance.cpp;
device_gemm_xdl_splitk_c_shuffle_f16_f16_f16_mk_kn_mn_instance.cpp;
device_gemm_xdl_splitk_c_shuffle_f16_f16_f16_mk_nk_mn_instance.cpp;
device_gemm_xdl_splitk_c_shuffle_f16_f16_f16_km_kn_mn_instance.cpp;
device_gemm_xdl_splitk_c_shuffle_f16_f16_f16_km_nk_mn_instance.cpp;
)
# Build the listed sources into a single shared library.
add_library
(
device_gemm_splitk_instance SHARED
${
DEVICE_GEMM_SPLITK_INSTANCE_SOURCE
}
)
# NOTE(review): no feature names follow PUBLIC here - confirm whether one was meant to be listed.
target_compile_features
(
device_gemm_splitk_instance PUBLIC
)
# Compile the library as position-independent code.
set_target_properties
(
device_gemm_splitk_instance PROPERTIES POSITION_INDEPENDENT_CODE ON
)
# Install the shared library into the lib directory.
install
(
TARGETS device_gemm_splitk_instance LIBRARY DESTINATION lib
)
# clang_tidy_check is defined elsewhere in the project; presumably it registers a clang-tidy run for this target - verify.
clang_tidy_check
(
device_gemm_splitk_instance
)
library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_splitk_c_shuffle_f16_f16_f16_km_kn_mn_instance.cpp
→
library/src/tensor_operation_instance/gpu/gemm
_splitk
/device_gemm_xdl_splitk_c_shuffle_f16_f16_f16_km_kn_mn_instance.cpp
View file @
4511f877
File moved
library/src/tensor_operation_instance/gpu/gemm/device_gemm_xdl_splitk_c_shuffle_f16_f16_f16_km_nk_mn_instance.cpp
→
library/src/tensor_operation_instance/gpu/gemm
_splitk
/device_gemm_xdl_splitk_c_shuffle_f16_f16_f16_km_nk_mn_instance.cpp
View file @
4511f877
File moved
Prev
1
2
3
4
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment