Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
7bae1691
Commit
7bae1691
authored
May 05, 2023
by
Po-Yen, Chen
Browse files
Adapt the new GridwiseGemm interface
parent
ceebf306
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
152 additions
and
266 deletions
+152
-266
include/ck/tensor_operation/gpu/device/impl/device_cgemm_4gemm_xdl_cshuffle.hpp
...ation/gpu/device/impl/device_cgemm_4gemm_xdl_cshuffle.hpp
+152
-266
No files found.
include/ck/tensor_operation/gpu/device/impl/device_cgemm_4gemm_xdl_cshuffle.hpp
View file @
7bae1691
...
@@ -389,6 +389,9 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
...
@@ -389,6 +389,9 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
// GridwiseGemm
// GridwiseGemm
using
GridwiseGemm
=
GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
<
using
GridwiseGemm
=
GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
<
ALayout
,
BLayout
,
CLayout
,
ADataType
,
// TODO: distinguish A/B datatype
ADataType
,
// TODO: distinguish A/B datatype
GemmAccDataType
,
GemmAccDataType
,
CShuffleDataType
,
CShuffleDataType
,
...
@@ -396,10 +399,8 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
...
@@ -396,10 +399,8 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
AElementwiseOperation
,
AElementwiseOperation
,
BElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
,
CElementwiseOperation
,
GemmSpec
,
InMemoryDataOperationEnum
::
Set
,
InMemoryDataOperationEnum
::
Set
,
AGridDesc_AK0_M_AK1
,
BGridDesc_BK0_N_BK1
,
CGridDesc_M_N
,
NumGemmKPrefetchStage
,
NumGemmKPrefetchStage
,
BlockSize
,
BlockSize
,
MPerBlock
,
MPerBlock
,
...
@@ -434,8 +435,10 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
...
@@ -434,8 +435,10 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
LoopSched
>
;
LoopSched
>
;
// Argument
// Argument
struct
Argument
:
public
Base
Argument
struct
Argument
:
public
GridwiseGemm
::
Argument
{
{
using
Parent
=
typename
GridwiseGemm
::
Argument
;
Argument
(
const
ADataType
*
p_a_grid_real
,
Argument
(
const
ADataType
*
p_a_grid_real
,
const
ADataType
*
p_a_grid_imag
,
const
ADataType
*
p_a_grid_imag
,
const
BDataType
*
p_b_grid_real
,
const
BDataType
*
p_b_grid_real
,
...
@@ -443,55 +446,53 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
...
@@ -443,55 +446,53 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
CDataType
*
p_c_grid_real
,
CDataType
*
p_c_grid_real
,
CDataType
*
p_c_grid_imag
,
CDataType
*
p_c_grid_imag
,
CDataType
*
p_workspace
,
CDataType
*
p_workspace
,
index_t
MRaw
,
index_t
M_
,
index_t
NRaw
,
index_t
N_
,
index_t
KRaw
,
index_t
K_
,
index_t
StrideA
,
index_t
StrideA_
,
index_t
StrideB
,
index_t
StrideB_
,
index_t
StrideC
,
index_t
StrideC_
,
AElementwiseOperation
a_element_op
,
index_t
MPadded_
,
BElementwiseOperation
b_element_op
,
index_t
NPadded_
,
CElementwiseOperation
c_element_op
)
index_t
KPadded_
,
:
p_a_grid_real_
{
p_a_grid_real
},
index_t
AK0_
,
index_t
BK0_
)
:
Parent
(
nullptr
,
nullptr
,
nullptr
,
M_
,
N_
,
K_
,
StrideA_
,
StrideB_
,
StrideC_
,
MPadded_
,
NPadded_
,
KPadded_
,
AK0_
,
BK0_
),
p_a_grid_real_
{
p_a_grid_real
},
p_a_grid_imag_
{
p_a_grid_imag
},
p_a_grid_imag_
{
p_a_grid_imag
},
p_b_grid_real_
{
p_b_grid_real
},
p_b_grid_real_
{
p_b_grid_real
},
p_b_grid_imag_
{
p_b_grid_imag
},
p_b_grid_imag_
{
p_b_grid_imag
},
p_c_grid_real_
{
p_c_grid_real
},
p_c_grid_real_
{
p_c_grid_real
},
p_c_grid_imag_
{
p_c_grid_imag
},
p_c_grid_imag_
{
p_c_grid_imag
},
p_aux_grid_
{
p_workspace
},
p_aux_grid_
{
p_workspace
}
a_grid_desc_ak0_m_ak1_
{
DeviceOp
::
MakeAGridDescriptor_AK0_M_AK1
(
MRaw
,
KRaw
,
StrideA
)},
b_grid_desc_bk0_n_bk1_
{
DeviceOp
::
MakeBGridDescriptor_BK0_N_BK1
(
KRaw
,
NRaw
,
StrideB
)},
c_grid_desc_m_n_
{
DeviceOp
::
MakeCGridDescriptor_M_N
(
MRaw
,
NRaw
,
StrideC
)},
c_grid_desc_mblock_mperblock_nblock_nperblock_
{},
block_2_ctile_map_
{
GridwiseGemm
::
MakeDefaultBlock2CTileMap
(
c_grid_desc_m_n_
)},
a_element_op_
{
a_element_op
},
b_element_op_
{
b_element_op
},
c_element_op_
{
c_element_op
}
{
if
(
GridwiseGemm
::
CheckValidity
(
a_grid_desc_ak0_m_ak1_
,
b_grid_desc_bk0_n_bk1_
,
c_grid_desc_m_n_
,
block_2_ctile_map_
))
{
{
c_grid_desc_mblock_mperblock_nblock_nperblock_
=
const
index_t
grid_size
=
std
::
get
<
1
>
(
GridwiseGemm
::
CalculateGridSize
(
M_
,
N_
));
GridwiseGemm
::
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
c_grid_desc_m_n_
);
}
const
index_t
grid_size
=
block_2_ctile_map_
.
CalculateGridSize
(
c_grid_desc_m_n_
);
if
constexpr
(
is_same
<
tensor_layout
::
gemm
::
RowMajor
,
CLayout
>::
value
)
if
constexpr
(
is_same
<
tensor_layout
::
gemm
::
RowMajor
,
CLayout
>::
value
)
{
{
c_grid_desc_m_
=
c_grid_desc_m_
=
DeviceOp
::
MakeDescriptor_M
({
M
Raw
,
NRaw
},
{
StrideC
,
I1
},
grid_size
,
BlockSize
);
DeviceOp
::
MakeDescriptor_M
({
M
_
,
N_
},
{
StrideC
_
,
I1
},
grid_size
,
BlockSize
);
}
}
else
if
constexpr
(
is_same
<
tensor_layout
::
gemm
::
ColumnMajor
,
CLayout
>::
value
)
else
if
constexpr
(
is_same
<
tensor_layout
::
gemm
::
ColumnMajor
,
CLayout
>::
value
)
{
{
c_grid_desc_m_
=
c_grid_desc_m_
=
DeviceOp
::
MakeDescriptor_M
({
M
Raw
,
NRaw
},
{
I1
,
StrideC
},
grid_size
,
BlockSize
);
DeviceOp
::
MakeDescriptor_M
({
M
_
,
N_
},
{
I1
,
StrideC
_
},
grid_size
,
BlockSize
);
}
}
p_aux_2_grid_
=
p_workspace
+
c_grid_desc_m_n
_
.
GetElementSpaceSize
();
p_aux_2_grid_
=
p_workspace
+
Parent
::
c_grid_desc_m_n
.
GetElementSpaceSize
();
}
}
// private:
// private:
...
@@ -503,38 +504,32 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
...
@@ -503,38 +504,32 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
CDataType
*
p_c_grid_imag_
;
CDataType
*
p_c_grid_imag_
;
CDataType
*
p_aux_grid_
;
CDataType
*
p_aux_grid_
;
CDataType
*
p_aux_2_grid_
;
CDataType
*
p_aux_2_grid_
;
AGridDesc_AK0_M_AK1
a_grid_desc_ak0_m_ak1_
;
BGridDesc_BK0_N_BK1
b_grid_desc_bk0_n_bk1_
;
CGridDesc_M_N
c_grid_desc_m_n_
;
CGridDesc_M
c_grid_desc_m_
;
CGridDesc_M
c_grid_desc_m_
;
typename
GridwiseGemm
::
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock_
;
typename
GridwiseGemm
::
DefaultBlock2CTileMap
block_2_ctile_map_
;
AElementwiseOperation
a_element_op_
;
BElementwiseOperation
b_element_op_
;
CElementwiseOperation
c_element_op_
;
};
};
// Invoker
// Invoker
struct
Invoker
:
public
BaseInvoker
struct
Invoker
:
public
BaseInvoker
{
{
using
Argument
=
DeviceOp
::
Argument
;
// void Print(const Argument& karg) { karg.Print(); }
float
Run
(
const
Argument
&
arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
float
Run
(
const
Argument
&
arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
{
{
if
(
!
GridwiseGemm
::
CheckValidity
(
arg
.
a_grid_desc_ak0_m_ak1_
,
Argument
karg
=
arg
;
arg
.
b_grid_desc_bk0_n_bk1_
,
arg
.
c_grid_desc_m_n_
,
if
(
stream_config
.
log_level_
>
0
)
arg
.
block_2_ctile_map_
))
{
// Print(karg);
}
if
(
!
GridwiseGemm
::
CheckValidity
(
karg
))
{
{
throw
std
::
runtime_error
(
"wrong! GridwiseGemm has invalid setting"
);
throw
std
::
runtime_error
(
"wrong! GridwiseGemm has invalid setting"
);
}
}
const
index_t
g
rid_size
=
index_t
g
dx
,
gdy
,
gdz
;
arg
.
block_2_ctile_map_
.
CalculateGridSize
(
arg
.
c_grid_desc_m_n_
);
std
::
tie
(
gdx
,
gdy
,
gdz
)
=
GridwiseGemm
::
CalculateGridSize
(
k
arg
.
M
,
karg
.
N
);
const
auto
K
=
const
auto
K
=
GridwiseGemm
::
CalculateAK0
(
karg
.
K
)
*
AK1
;
arg
.
a_grid_desc_ak0_m_ak1_
.
GetLength
(
I0
)
*
arg
.
a_grid_desc_ak0_m_ak1_
.
GetLength
(
I2
);
float
ave_time
=
0
;
float
ave_time
=
0
;
...
@@ -578,224 +573,114 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
...
@@ -578,224 +573,114 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
if
(
GridwiseGemm
::
CalculateHasMainKBlockLoop
(
K
))
if
(
GridwiseGemm
::
CalculateHasMainKBlockLoop
(
K
))
{
{
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v1
<
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v1_simplified
<
GridwiseGemm
,
true
>
;
GridwiseGemm
,
ADataType
,
// TODO: distiguish A/B datatype
karg
.
p_a_grid
=
karg
.
p_a_grid_real_
;
CDataType
,
karg
.
p_b_grid
=
karg
.
p_b_grid_real_
;
AElementwiseOperation
,
karg
.
p_c_grid
=
karg
.
p_aux_grid_
;
BElementwiseOperation
,
ave_time
+=
launch_and_time_kernel
(
CElementwiseOperation
,
stream_config
,
kernel
,
dim3
(
gdx
,
gdy
,
gdz
),
dim3
(
BlockSize
),
0
,
karg
);
DeviceOp
::
AGridDesc_AK0_M_AK1
,
DeviceOp
::
BGridDesc_BK0_N_BK1
,
karg
.
p_a_grid
=
karg
.
p_a_grid_imag_
;
typename
GridwiseGemm
::
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
karg
.
p_b_grid
=
karg
.
p_b_grid_imag_
;
typename
GridwiseGemm
::
DefaultBlock2CTileMap
,
karg
.
p_c_grid
=
karg
.
p_aux_2_grid_
;
true
>
;
ave_time
+=
launch_and_time_kernel
(
stream_config
,
kernel
,
dim3
(
gdx
,
gdy
,
gdz
),
dim3
(
BlockSize
),
0
,
karg
);
ave_time
+=
launch_and_time_kernel
(
stream_config
,
kernel
,
dim3
(
grid_size
),
dim3
(
BlockSize
),
0
,
arg
.
p_a_grid_real_
,
arg
.
p_b_grid_real_
,
arg
.
p_aux_grid_
,
arg
.
a_element_op_
,
arg
.
b_element_op_
,
arg
.
c_element_op_
,
arg
.
a_grid_desc_ak0_m_ak1_
,
arg
.
b_grid_desc_bk0_n_bk1_
,
arg
.
c_grid_desc_mblock_mperblock_nblock_nperblock_
,
arg
.
block_2_ctile_map_
);
ave_time
+=
launch_and_time_kernel
(
stream_config
,
kernel
,
dim3
(
grid_size
),
dim3
(
BlockSize
),
0
,
arg
.
p_a_grid_imag_
,
arg
.
p_b_grid_imag_
,
arg
.
p_aux_2_grid_
,
arg
.
a_element_op_
,
arg
.
b_element_op_
,
arg
.
c_element_op_
,
arg
.
a_grid_desc_ak0_m_ak1_
,
arg
.
b_grid_desc_bk0_n_bk1_
,
arg
.
c_grid_desc_mblock_mperblock_nblock_nperblock_
,
arg
.
block_2_ctile_map_
);
// c_real = aux - aux_2
// c_real = aux - aux_2
ave_time
+=
launch_and_time_kernel
(
ave_time
+=
launch_and_time_kernel
(
stream_config
,
stream_config
,
subtract_kernel
,
subtract_kernel
,
dim3
(
g
rid_size
),
dim3
(
g
dx
,
gdy
,
gdz
),
dim3
(
BlockSize
),
dim3
(
BlockSize
),
0
,
0
,
make_tuple
(
arg
.
c_grid_desc_m_
,
arg
.
c_grid_desc_m_
),
make_tuple
(
k
arg
.
c_grid_desc_m_
,
k
arg
.
c_grid_desc_m_
),
make_tuple
(
arg
.
c_grid_desc_m_
),
make_tuple
(
k
arg
.
c_grid_desc_m_
),
make_tuple
(
const_cast
<
const
CDataType
*>
(
arg
.
p_aux_grid_
),
make_tuple
(
const_cast
<
const
CDataType
*>
(
k
arg
.
p_aux_grid_
),
const_cast
<
const
CDataType
*>
(
arg
.
p_aux_2_grid_
)),
const_cast
<
const
CDataType
*>
(
k
arg
.
p_aux_2_grid_
)),
make_tuple
(
arg
.
p_c_grid_real_
),
make_tuple
(
k
arg
.
p_c_grid_real_
),
Subtract
{});
Subtract
{});
ave_time
+=
karg
.
p_a_grid
=
karg
.
p_a_grid_real_
;
launch_and_time_kernel
(
stream_config
,
karg
.
p_b_grid
=
karg
.
p_b_grid_imag_
;
kernel
,
karg
.
p_c_grid
=
karg
.
p_aux_grid_
;
dim3
(
grid_size
),
ave_time
+=
launch_and_time_kernel
(
dim3
(
BlockSize
),
stream_config
,
kernel
,
dim3
(
gdx
,
gdy
,
gdz
),
dim3
(
BlockSize
),
0
,
karg
);
0
,
arg
.
p_a_grid_real_
,
karg
.
p_a_grid
=
karg
.
p_a_grid_imag_
;
arg
.
p_b_grid_imag_
,
karg
.
p_b_grid
=
karg
.
p_b_grid_real_
;
arg
.
p_aux_grid_
,
karg
.
p_c_grid
=
karg
.
p_aux_2_grid_
;
arg
.
a_element_op_
,
ave_time
+=
launch_and_time_kernel
(
arg
.
b_element_op_
,
stream_config
,
kernel
,
dim3
(
gdx
,
gdy
,
gdz
),
dim3
(
BlockSize
),
0
,
karg
);
arg
.
c_element_op_
,
arg
.
a_grid_desc_ak0_m_ak1_
,
arg
.
b_grid_desc_bk0_n_bk1_
,
arg
.
c_grid_desc_mblock_mperblock_nblock_nperblock_
,
arg
.
block_2_ctile_map_
);
ave_time
+=
launch_and_time_kernel
(
stream_config
,
kernel
,
dim3
(
grid_size
),
dim3
(
BlockSize
),
0
,
arg
.
p_a_grid_imag_
,
arg
.
p_b_grid_real_
,
arg
.
p_aux_2_grid_
,
arg
.
a_element_op_
,
arg
.
b_element_op_
,
arg
.
c_element_op_
,
arg
.
a_grid_desc_ak0_m_ak1_
,
arg
.
b_grid_desc_bk0_n_bk1_
,
arg
.
c_grid_desc_mblock_mperblock_nblock_nperblock_
,
arg
.
block_2_ctile_map_
);
// c_imag = aux + aux_2
// c_imag = aux + aux_2
ave_time
+=
launch_and_time_kernel
(
ave_time
+=
launch_and_time_kernel
(
stream_config
,
stream_config
,
add_kernel
,
add_kernel
,
dim3
(
g
rid_size
),
dim3
(
g
dx
,
gdy
,
gdz
),
dim3
(
BlockSize
),
dim3
(
BlockSize
),
0
,
0
,
make_tuple
(
arg
.
c_grid_desc_m_
,
arg
.
c_grid_desc_m_
),
make_tuple
(
k
arg
.
c_grid_desc_m_
,
k
arg
.
c_grid_desc_m_
),
make_tuple
(
arg
.
c_grid_desc_m_
),
make_tuple
(
k
arg
.
c_grid_desc_m_
),
make_tuple
(
const_cast
<
const
CDataType
*>
(
arg
.
p_aux_grid_
),
make_tuple
(
const_cast
<
const
CDataType
*>
(
k
arg
.
p_aux_grid_
),
const_cast
<
const
CDataType
*>
(
arg
.
p_aux_2_grid_
)),
const_cast
<
const
CDataType
*>
(
k
arg
.
p_aux_2_grid_
)),
make_tuple
(
arg
.
p_c_grid_imag_
),
make_tuple
(
k
arg
.
p_c_grid_imag_
),
Add
{});
Add
{});
}
}
else
else
{
{
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v1
<
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v1_simplified
<
GridwiseGemm
,
false
>
;
GridwiseGemm
,
ADataType
,
// TODO: distiguish A/B datatype
karg
.
p_a_grid
=
karg
.
p_a_grid_real_
;
CDataType
,
karg
.
p_b_grid
=
karg
.
p_b_grid_real_
;
AElementwiseOperation
,
karg
.
p_c_grid
=
karg
.
p_aux_grid_
;
BElementwiseOperation
,
ave_time
+=
launch_and_time_kernel
(
CElementwiseOperation
,
stream_config
,
kernel
,
dim3
(
gdx
,
gdy
,
gdz
),
dim3
(
BlockSize
),
0
,
karg
);
DeviceOp
::
AGridDesc_AK0_M_AK1
,
DeviceOp
::
BGridDesc_BK0_N_BK1
,
karg
.
p_a_grid
=
karg
.
p_a_grid_imag_
;
typename
GridwiseGemm
::
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
karg
.
p_b_grid
=
karg
.
p_b_grid_imag_
;
typename
GridwiseGemm
::
DefaultBlock2CTileMap
,
karg
.
p_c_grid
=
karg
.
p_aux_2_grid_
;
false
>
;
ave_time
+=
launch_and_time_kernel
(
stream_config
,
kernel
,
dim3
(
gdx
,
gdy
,
gdz
),
dim3
(
BlockSize
),
0
,
karg
);
ave_time
+=
launch_and_time_kernel
(
stream_config
,
kernel
,
dim3
(
grid_size
),
dim3
(
BlockSize
),
0
,
arg
.
p_a_grid_real_
,
arg
.
p_b_grid_real_
,
arg
.
p_aux_grid_
,
arg
.
a_element_op_
,
arg
.
b_element_op_
,
arg
.
c_element_op_
,
arg
.
a_grid_desc_ak0_m_ak1_
,
arg
.
b_grid_desc_bk0_n_bk1_
,
arg
.
c_grid_desc_mblock_mperblock_nblock_nperblock_
,
arg
.
block_2_ctile_map_
);
ave_time
+=
launch_and_time_kernel
(
stream_config
,
kernel
,
dim3
(
grid_size
),
dim3
(
BlockSize
),
0
,
arg
.
p_a_grid_imag_
,
arg
.
p_b_grid_imag_
,
arg
.
p_aux_2_grid_
,
arg
.
a_element_op_
,
arg
.
b_element_op_
,
arg
.
c_element_op_
,
arg
.
a_grid_desc_ak0_m_ak1_
,
arg
.
b_grid_desc_bk0_n_bk1_
,
arg
.
c_grid_desc_mblock_mperblock_nblock_nperblock_
,
arg
.
block_2_ctile_map_
);
// c_real = aux - aux_2
// c_real = aux - aux_2
ave_time
+=
launch_and_time_kernel
(
ave_time
+=
launch_and_time_kernel
(
stream_config
,
stream_config
,
subtract_kernel
,
subtract_kernel
,
dim3
(
g
rid_size
),
dim3
(
g
dx
,
gdy
,
gdz
),
dim3
(
BlockSize
),
dim3
(
BlockSize
),
0
,
0
,
make_tuple
(
arg
.
c_grid_desc_m_
,
arg
.
c_grid_desc_m_
),
make_tuple
(
k
arg
.
c_grid_desc_m_
,
k
arg
.
c_grid_desc_m_
),
make_tuple
(
arg
.
c_grid_desc_m_
),
make_tuple
(
k
arg
.
c_grid_desc_m_
),
make_tuple
(
const_cast
<
const
CDataType
*>
(
arg
.
p_aux_grid_
),
make_tuple
(
const_cast
<
const
CDataType
*>
(
k
arg
.
p_aux_grid_
),
const_cast
<
const
CDataType
*>
(
arg
.
p_aux_2_grid_
)),
const_cast
<
const
CDataType
*>
(
k
arg
.
p_aux_2_grid_
)),
make_tuple
(
arg
.
p_c_grid_real_
),
make_tuple
(
k
arg
.
p_c_grid_real_
),
Subtract
{});
Subtract
{});
ave_time
+=
karg
.
p_a_grid
=
karg
.
p_a_grid_real_
;
launch_and_time_kernel
(
stream_config
,
karg
.
p_b_grid
=
karg
.
p_b_grid_imag_
;
kernel
,
karg
.
p_c_grid
=
karg
.
p_aux_grid_
;
dim3
(
grid_size
),
ave_time
+=
launch_and_time_kernel
(
dim3
(
BlockSize
),
stream_config
,
kernel
,
dim3
(
gdx
,
gdy
,
gdz
),
dim3
(
BlockSize
),
0
,
karg
);
0
,
arg
.
p_a_grid_real_
,
karg
.
p_a_grid
=
karg
.
p_a_grid_imag_
;
arg
.
p_b_grid_imag_
,
karg
.
p_b_grid
=
karg
.
p_b_grid_real_
;
arg
.
p_aux_grid_
,
karg
.
p_c_grid
=
karg
.
p_aux_2_grid_
;
arg
.
a_element_op_
,
ave_time
+=
launch_and_time_kernel
(
arg
.
b_element_op_
,
stream_config
,
kernel
,
dim3
(
gdx
,
gdy
,
gdz
),
dim3
(
BlockSize
),
0
,
karg
);
arg
.
c_element_op_
,
arg
.
a_grid_desc_ak0_m_ak1_
,
arg
.
b_grid_desc_bk0_n_bk1_
,
arg
.
c_grid_desc_mblock_mperblock_nblock_nperblock_
,
arg
.
block_2_ctile_map_
);
ave_time
+=
launch_and_time_kernel
(
stream_config
,
kernel
,
dim3
(
grid_size
),
dim3
(
BlockSize
),
0
,
arg
.
p_a_grid_imag_
,
arg
.
p_b_grid_real_
,
arg
.
p_aux_2_grid_
,
arg
.
a_element_op_
,
arg
.
b_element_op_
,
arg
.
c_element_op_
,
arg
.
a_grid_desc_ak0_m_ak1_
,
arg
.
b_grid_desc_bk0_n_bk1_
,
arg
.
c_grid_desc_mblock_mperblock_nblock_nperblock_
,
arg
.
block_2_ctile_map_
);
// c_imag = aux + aux_2
// c_imag = aux + aux_2
ave_time
+=
launch_and_time_kernel
(
ave_time
+=
launch_and_time_kernel
(
stream_config
,
stream_config
,
add_kernel
,
add_kernel
,
dim3
(
g
rid_size
),
dim3
(
g
dx
,
gdy
,
gdz
),
dim3
(
BlockSize
),
dim3
(
BlockSize
),
0
,
0
,
make_tuple
(
arg
.
c_grid_desc_m_
,
arg
.
c_grid_desc_m_
),
make_tuple
(
k
arg
.
c_grid_desc_m_
,
k
arg
.
c_grid_desc_m_
),
make_tuple
(
arg
.
c_grid_desc_m_
),
make_tuple
(
k
arg
.
c_grid_desc_m_
),
make_tuple
(
const_cast
<
const
CDataType
*>
(
arg
.
p_aux_grid_
),
make_tuple
(
const_cast
<
const
CDataType
*>
(
k
arg
.
p_aux_grid_
),
const_cast
<
const
CDataType
*>
(
arg
.
p_aux_2_grid_
)),
const_cast
<
const
CDataType
*>
(
k
arg
.
p_aux_2_grid_
)),
make_tuple
(
arg
.
p_c_grid_imag_
),
make_tuple
(
k
arg
.
p_c_grid_imag_
),
Add
{});
Add
{});
}
}
...
@@ -816,12 +701,9 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
...
@@ -816,12 +701,9 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
return
true
;
return
true
;
}
}
static
bool
IsSupportedArgument
(
const
Argument
&
arg
)
static
bool
IsSupportedArgument
(
const
Argument
&
k
arg
)
{
{
return
GridwiseGemm
::
CheckValidity
(
arg
.
a_grid_desc_ak0_m_ak1_
,
return
GridwiseGemm
::
CheckValidity
(
karg
);
arg
.
b_grid_desc_bk0_n_bk1_
,
arg
.
c_grid_desc_m_n_
,
arg
.
block_2_ctile_map_
);
}
}
// polymorphic
// polymorphic
...
@@ -837,15 +719,15 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
...
@@ -837,15 +719,15 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
CDataType
*
p_c_real
,
CDataType
*
p_c_real
,
CDataType
*
p_c_imag
,
CDataType
*
p_c_imag
,
CDataType
*
p_workspace
,
CDataType
*
p_workspace
,
index_t
M
Raw
,
index_t
M
,
index_t
N
Raw
,
index_t
N
,
index_t
K
Raw
,
index_t
K
,
index_t
StrideA
,
index_t
StrideA
,
index_t
StrideB
,
index_t
StrideB
,
index_t
StrideC
,
index_t
StrideC
,
AElementwiseOperation
a_element_op
,
AElementwiseOperation
,
BElementwiseOperation
b_element_op
,
BElementwiseOperation
,
CElementwiseOperation
c_element_op
)
CElementwiseOperation
)
{
{
return
Argument
{
p_a_real
,
return
Argument
{
p_a_real
,
p_a_imag
,
p_a_imag
,
...
@@ -854,15 +736,17 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
...
@@ -854,15 +736,17 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
p_c_real
,
p_c_real
,
p_c_imag
,
p_c_imag
,
p_workspace
,
p_workspace
,
M
Raw
,
M
,
N
Raw
,
N
,
K
Raw
,
K
,
StrideA
,
StrideA
,
StrideB
,
StrideB
,
StrideC
,
StrideC
,
a_element_op
,
GridwiseGemm
::
CalculateMPadded
(
M
),
b_element_op
,
GridwiseGemm
::
CalculateNPadded
(
N
),
c_element_op
};
GridwiseGemm
::
CalculateKPadded
(
K
),
GridwiseGemm
::
CalculateAK0
(
K
),
GridwiseGemm
::
CalculateBK0
(
K
)};
}
}
static
auto
MakeInvoker
()
{
return
Invoker
{};
}
static
auto
MakeInvoker
()
{
return
Invoker
{};
}
...
@@ -875,15 +759,15 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
...
@@ -875,15 +759,15 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
void
*
p_c_real
,
void
*
p_c_real
,
void
*
p_c_imag
,
void
*
p_c_imag
,
void
*
p_workspace
,
void
*
p_workspace
,
index_t
M
Raw
,
index_t
M
,
index_t
N
Raw
,
index_t
N
,
index_t
K
Raw
,
index_t
K
,
index_t
StrideA
,
index_t
StrideA
,
index_t
StrideB
,
index_t
StrideB
,
index_t
StrideC
,
index_t
StrideC
,
AElementwiseOperation
a_element_op
,
AElementwiseOperation
,
BElementwiseOperation
b_element_op
,
BElementwiseOperation
,
CElementwiseOperation
c_element_op
,
CElementwiseOperation
,
index_t
/* KBatch */
=
1
)
override
index_t
/* KBatch */
=
1
)
override
{
{
return
std
::
make_unique
<
Argument
>
(
static_cast
<
const
ADataType
*>
(
p_a_real
),
return
std
::
make_unique
<
Argument
>
(
static_cast
<
const
ADataType
*>
(
p_a_real
),
...
@@ -893,15 +777,17 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
...
@@ -893,15 +777,17 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
static_cast
<
CDataType
*>
(
p_c_real
),
static_cast
<
CDataType
*>
(
p_c_real
),
static_cast
<
CDataType
*>
(
p_c_imag
),
static_cast
<
CDataType
*>
(
p_c_imag
),
static_cast
<
CDataType
*>
(
p_workspace
),
static_cast
<
CDataType
*>
(
p_workspace
),
M
Raw
,
M
,
N
Raw
,
N
,
K
Raw
,
K
,
StrideA
,
StrideA
,
StrideB
,
StrideB
,
StrideC
,
StrideC
,
a_element_op
,
GridwiseGemm
::
CalculateMPadded
(
M
),
b_element_op
,
GridwiseGemm
::
CalculateNPadded
(
N
),
c_element_op
);
GridwiseGemm
::
CalculateKPadded
(
K
),
GridwiseGemm
::
CalculateAK0
(
K
),
GridwiseGemm
::
CalculateBK0
(
K
));
}
}
// polymorphic
// polymorphic
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment