Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
0a929502
Commit
0a929502
authored
May 07, 2023
by
Po-Yen, Chen
Browse files
Remove trailing underscore in public attribute name
parent
f4ea00fc
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
69 additions
and
69 deletions
+69
-69
include/ck/tensor_operation/gpu/device/impl/device_cgemm_4gemm_xdl_cshuffle.hpp
...ation/gpu/device/impl/device_cgemm_4gemm_xdl_cshuffle.hpp
+69
-69
No files found.
include/ck/tensor_operation/gpu/device/impl/device_cgemm_4gemm_xdl_cshuffle.hpp
View file @
0a929502
...
...
@@ -172,12 +172,12 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
{
using
Parent
=
typename
GridwiseGemm
::
Argument
;
Argument
(
const
ADataType
*
p_a_grid_real
,
const
ADataType
*
p_a_grid_imag
,
const
BDataType
*
p_b_grid_real
,
const
BDataType
*
p_b_grid_imag
,
CDataType
*
p_c_grid_real
,
CDataType
*
p_c_grid_imag
,
Argument
(
const
ADataType
*
p_a_grid_real
_
,
const
ADataType
*
p_a_grid_imag
_
,
const
BDataType
*
p_b_grid_real
_
,
const
BDataType
*
p_b_grid_imag
_
,
CDataType
*
p_c_grid_real
_
,
CDataType
*
p_c_grid_imag
_
,
CDataType
*
p_workspace
,
index_t
M_
,
index_t
N_
,
...
...
@@ -196,40 +196,40 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
GridwiseGemm
::
CalculateKPadded
(
K_
),
GridwiseGemm
::
CalculateAK0
(
K_
),
GridwiseGemm
::
CalculateBK0
(
K_
)),
p_a_grid_real
_
{
p_a_grid_real
},
p_a_grid_imag
_
{
p_a_grid_imag
},
p_b_grid_real
_
{
p_b_grid_real
},
p_b_grid_imag
_
{
p_b_grid_imag
},
p_c_grid_real
_
{
p_c_grid_real
},
p_c_grid_imag
_
{
p_c_grid_imag
},
p_aux_grid
_
{
p_workspace
}
p_a_grid_real
{
p_a_grid_real
_
},
p_a_grid_imag
{
p_a_grid_imag
_
},
p_b_grid_real
{
p_b_grid_real
_
},
p_b_grid_imag
{
p_b_grid_imag
_
},
p_c_grid_real
{
p_c_grid_real
_
},
p_c_grid_imag
{
p_c_grid_imag
_
},
p_aux_grid
{
p_workspace
}
{
const
index_t
grid_size
=
std
::
get
<
1
>
(
GridwiseGemm
::
CalculateGridSize
(
M_
,
N_
));
if
constexpr
(
is_same
<
tensor_layout
::
gemm
::
RowMajor
,
CLayout
>::
value
)
{
c_grid_desc_m
_
=
c_grid_desc_m
=
DeviceOp
::
MakeDescriptor_M
({
M_
,
N_
},
{
StrideC_
,
I1
},
grid_size
,
BlockSize
);
}
else
if
constexpr
(
is_same
<
tensor_layout
::
gemm
::
ColumnMajor
,
CLayout
>::
value
)
{
c_grid_desc_m
_
=
c_grid_desc_m
=
DeviceOp
::
MakeDescriptor_M
({
M_
,
N_
},
{
I1
,
StrideC_
},
grid_size
,
BlockSize
);
}
p_aux_2_grid
_
=
p_workspace
+
GetCElementSpaceSize
(
M_
,
N_
,
StrideC_
);
p_aux_2_grid
=
p_workspace
+
GetCElementSpaceSize
(
M_
,
N_
,
StrideC_
);
}
// private:
const
ADataType
*
p_a_grid_real
_
;
const
ADataType
*
p_a_grid_imag
_
;
const
BDataType
*
p_b_grid_real
_
;
const
BDataType
*
p_b_grid_imag
_
;
CDataType
*
p_c_grid_real
_
;
CDataType
*
p_c_grid_imag
_
;
CDataType
*
p_aux_grid
_
;
CDataType
*
p_aux_2_grid
_
;
CGridDesc_M
c_grid_desc_m
_
;
const
ADataType
*
p_a_grid_real
;
const
ADataType
*
p_a_grid_imag
;
const
BDataType
*
p_b_grid_real
;
const
BDataType
*
p_b_grid_imag
;
CDataType
*
p_c_grid_real
;
CDataType
*
p_c_grid_imag
;
CDataType
*
p_aux_grid
;
CDataType
*
p_aux_2_grid
;
CGridDesc_M
c_grid_desc_m
;
};
// Invoker
...
...
@@ -303,9 +303,9 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
dim3
(
gdx
,
gdy
,
gdz
),
dim3
(
BlockSize
),
0
,
karg
.
p_a_grid_real
_
,
karg
.
p_b_grid_real
_
,
karg
.
p_aux_grid
_
,
karg
.
p_a_grid_real
,
karg
.
p_b_grid_real
,
karg
.
p_aux_grid
,
karg
);
ave_time
+=
launch_and_time_kernel
(
stream_config
,
...
...
@@ -313,9 +313,9 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
dim3
(
gdx
,
gdy
,
gdz
),
dim3
(
BlockSize
),
0
,
karg
.
p_a_grid_imag
_
,
karg
.
p_b_grid_imag
_
,
karg
.
p_aux_2_grid
_
,
karg
.
p_a_grid_imag
,
karg
.
p_b_grid_imag
,
karg
.
p_aux_2_grid
,
karg
);
// c_real = aux - aux_2
...
...
@@ -325,11 +325,11 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
dim3
(
gdx
,
gdy
,
gdz
),
dim3
(
BlockSize
),
0
,
make_tuple
(
karg
.
c_grid_desc_m
_
,
karg
.
c_grid_desc_m
_
),
make_tuple
(
karg
.
c_grid_desc_m
_
),
make_tuple
(
const_cast
<
const
CDataType
*>
(
karg
.
p_aux_grid
_
),
const_cast
<
const
CDataType
*>
(
karg
.
p_aux_2_grid
_
)),
make_tuple
(
karg
.
p_c_grid_real
_
),
make_tuple
(
karg
.
c_grid_desc_m
,
karg
.
c_grid_desc_m
),
make_tuple
(
karg
.
c_grid_desc_m
),
make_tuple
(
const_cast
<
const
CDataType
*>
(
karg
.
p_aux_grid
),
const_cast
<
const
CDataType
*>
(
karg
.
p_aux_2_grid
)),
make_tuple
(
karg
.
p_c_grid_real
),
Subtract
{});
ave_time
+=
launch_and_time_kernel
(
stream_config
,
...
...
@@ -337,9 +337,9 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
dim3
(
gdx
,
gdy
,
gdz
),
dim3
(
BlockSize
),
0
,
karg
.
p_a_grid_real
_
,
karg
.
p_b_grid_imag
_
,
karg
.
p_aux_grid
_
,
karg
.
p_a_grid_real
,
karg
.
p_b_grid_imag
,
karg
.
p_aux_grid
,
karg
);
ave_time
+=
launch_and_time_kernel
(
stream_config
,
...
...
@@ -347,9 +347,9 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
dim3
(
gdx
,
gdy
,
gdz
),
dim3
(
BlockSize
),
0
,
karg
.
p_a_grid_imag
_
,
karg
.
p_b_grid_real
_
,
karg
.
p_aux_2_grid
_
,
karg
.
p_a_grid_imag
,
karg
.
p_b_grid_real
,
karg
.
p_aux_2_grid
,
karg
);
// c_imag = aux + aux_2
...
...
@@ -359,11 +359,11 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
dim3
(
gdx
,
gdy
,
gdz
),
dim3
(
BlockSize
),
0
,
make_tuple
(
karg
.
c_grid_desc_m
_
,
karg
.
c_grid_desc_m
_
),
make_tuple
(
karg
.
c_grid_desc_m
_
),
make_tuple
(
const_cast
<
const
CDataType
*>
(
karg
.
p_aux_grid
_
),
const_cast
<
const
CDataType
*>
(
karg
.
p_aux_2_grid
_
)),
make_tuple
(
karg
.
p_c_grid_imag
_
),
make_tuple
(
karg
.
c_grid_desc_m
,
karg
.
c_grid_desc_m
),
make_tuple
(
karg
.
c_grid_desc_m
),
make_tuple
(
const_cast
<
const
CDataType
*>
(
karg
.
p_aux_grid
),
const_cast
<
const
CDataType
*>
(
karg
.
p_aux_2_grid
)),
make_tuple
(
karg
.
p_c_grid_imag
),
Add
{});
}
else
...
...
@@ -375,9 +375,9 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
dim3
(
gdx
,
gdy
,
gdz
),
dim3
(
BlockSize
),
0
,
karg
.
p_a_grid_real
_
,
karg
.
p_b_grid_real
_
,
karg
.
p_aux_grid
_
,
karg
.
p_a_grid_real
,
karg
.
p_b_grid_real
,
karg
.
p_aux_grid
,
karg
);
ave_time
+=
launch_and_time_kernel
(
stream_config
,
...
...
@@ -385,9 +385,9 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
dim3
(
gdx
,
gdy
,
gdz
),
dim3
(
BlockSize
),
0
,
karg
.
p_a_grid_imag
_
,
karg
.
p_b_grid_imag
_
,
karg
.
p_aux_2_grid
_
,
karg
.
p_a_grid_imag
,
karg
.
p_b_grid_imag
,
karg
.
p_aux_2_grid
,
karg
);
// c_real = aux - aux_2
...
...
@@ -397,11 +397,11 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
dim3
(
gdx
,
gdy
,
gdz
),
dim3
(
BlockSize
),
0
,
make_tuple
(
karg
.
c_grid_desc_m
_
,
karg
.
c_grid_desc_m
_
),
make_tuple
(
karg
.
c_grid_desc_m
_
),
make_tuple
(
const_cast
<
const
CDataType
*>
(
karg
.
p_aux_grid
_
),
const_cast
<
const
CDataType
*>
(
karg
.
p_aux_2_grid
_
)),
make_tuple
(
karg
.
p_c_grid_real
_
),
make_tuple
(
karg
.
c_grid_desc_m
,
karg
.
c_grid_desc_m
),
make_tuple
(
karg
.
c_grid_desc_m
),
make_tuple
(
const_cast
<
const
CDataType
*>
(
karg
.
p_aux_grid
),
const_cast
<
const
CDataType
*>
(
karg
.
p_aux_2_grid
)),
make_tuple
(
karg
.
p_c_grid_real
),
Subtract
{});
ave_time
+=
launch_and_time_kernel
(
stream_config
,
...
...
@@ -409,9 +409,9 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
dim3
(
gdx
,
gdy
,
gdz
),
dim3
(
BlockSize
),
0
,
karg
.
p_a_grid_real
_
,
karg
.
p_b_grid_imag
_
,
karg
.
p_aux_grid
_
,
karg
.
p_a_grid_real
,
karg
.
p_b_grid_imag
,
karg
.
p_aux_grid
,
karg
);
ave_time
+=
launch_and_time_kernel
(
stream_config
,
...
...
@@ -419,9 +419,9 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
dim3
(
gdx
,
gdy
,
gdz
),
dim3
(
BlockSize
),
0
,
karg
.
p_a_grid_imag
_
,
karg
.
p_b_grid_real
_
,
karg
.
p_aux_2_grid
_
,
karg
.
p_a_grid_imag
,
karg
.
p_b_grid_real
,
karg
.
p_aux_2_grid
,
karg
);
// c_imag = aux + aux_2
...
...
@@ -431,11 +431,11 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
dim3
(
gdx
,
gdy
,
gdz
),
dim3
(
BlockSize
),
0
,
make_tuple
(
karg
.
c_grid_desc_m
_
,
karg
.
c_grid_desc_m
_
),
make_tuple
(
karg
.
c_grid_desc_m
_
),
make_tuple
(
const_cast
<
const
CDataType
*>
(
karg
.
p_aux_grid
_
),
const_cast
<
const
CDataType
*>
(
karg
.
p_aux_2_grid
_
)),
make_tuple
(
karg
.
p_c_grid_imag
_
),
make_tuple
(
karg
.
c_grid_desc_m
,
karg
.
c_grid_desc_m
),
make_tuple
(
karg
.
c_grid_desc_m
),
make_tuple
(
const_cast
<
const
CDataType
*>
(
karg
.
p_aux_grid
),
const_cast
<
const
CDataType
*>
(
karg
.
p_aux_2_grid
)),
make_tuple
(
karg
.
p_c_grid_imag
),
Add
{});
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment