Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
7c686fc2
Commit
7c686fc2
authored
Jul 04, 2023
by
ltqin
Browse files
remove useless parameter
parent
2416ddf7
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
0 additions
and
12 deletions
+0
-12
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_bwd_xdl_cshuffle_qloop_light_v1.hpp
...pl/device_batched_mha_bwd_xdl_cshuffle_qloop_light_v1.hpp
+0
-4
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_bwd_xdl_cshuffle_qloop_light_v2.hpp
...pl/device_batched_mha_bwd_xdl_cshuffle_qloop_light_v2.hpp
+0
-4
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_light_v1.hpp
...dwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_light_v1.hpp
+0
-2
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_light_v2.hpp
...dwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_light_v2.hpp
+0
-2
No files found.
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_bwd_xdl_cshuffle_qloop_light_v1.hpp
View file @
7c686fc2
...
...
@@ -117,7 +117,6 @@ __global__ void
const
InputDataType
*
__restrict__
p_b_grid
,
ZDataType
*
__restrict__
p_z_grid
,
const
InputDataType
*
__restrict__
p_b1_grid
,
const
InputDataType
*
__restrict__
p_c_grid
,
const
LSEDataType
*
__restrict__
p_lse_grid
,
const
DDataType
*
__restrict__
p_d_grid
,
const
InputDataType
*
__restrict__
p_ygrad_grid
,
...
...
@@ -185,7 +184,6 @@ __global__ void
p_b_grid
+
b_batch_offset
,
z_matrix_ptr
,
p_b1_grid
+
b1_batch_offset
,
p_c_grid
+
c_batch_offset
,
p_lse_grid
+
lse_batch_offset
,
p_d_grid
+
lse_batch_offset
,
p_ygrad_grid
+
c_batch_offset
,
...
...
@@ -221,7 +219,6 @@ __global__ void
p_b_grid
+
b_batch_offset
,
z_matrix_ptr
,
p_b1_grid
+
b1_batch_offset
,
p_c_grid
+
c_batch_offset
,
p_lse_grid
+
lse_batch_offset
,
p_d_grid
+
lse_batch_offset
,
p_ygrad_grid
+
c_batch_offset
,
...
...
@@ -1105,7 +1102,6 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
arg
.
p_b_grid_
,
arg
.
p_z_grid_
,
arg
.
p_b1_grid_
,
arg
.
p_c_grid_
,
arg
.
p_lse_grid_
,
arg
.
p_d_grid_
,
arg
.
p_ygrad_grid_
,
...
...
include/ck/tensor_operation/gpu/device/impl/device_batched_mha_bwd_xdl_cshuffle_qloop_light_v2.hpp
View file @
7c686fc2
...
...
@@ -116,7 +116,6 @@ __global__ void
const
InputDataType
*
__restrict__
p_b_grid
,
ZDataType
*
__restrict__
p_z_grid
,
const
InputDataType
*
__restrict__
p_b1_grid
,
const
InputDataType
*
__restrict__
p_c_grid
,
const
LSEDataType
*
__restrict__
p_lse_grid
,
const
DDataType
*
__restrict__
p_d_grid
,
const
InputDataType
*
__restrict__
p_ygrad_grid
,
...
...
@@ -184,7 +183,6 @@ __global__ void
p_b_grid
+
b_batch_offset
,
z_matrix_ptr
,
p_b1_grid
+
b1_batch_offset
,
p_c_grid
+
c_batch_offset
,
p_lse_grid
+
lse_batch_offset
,
p_d_grid
+
lse_batch_offset
,
p_ygrad_grid
+
c_batch_offset
,
...
...
@@ -220,7 +218,6 @@ __global__ void
p_b_grid
+
b_batch_offset
,
z_matrix_ptr
,
p_b1_grid
+
b1_batch_offset
,
p_c_grid
+
c_batch_offset
,
p_lse_grid
+
lse_batch_offset
,
p_d_grid
+
lse_batch_offset
,
p_ygrad_grid
+
c_batch_offset
,
...
...
@@ -1122,7 +1119,6 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
arg
.
p_b_grid_
,
arg
.
p_z_grid_
,
arg
.
p_b1_grid_
,
arg
.
p_c_grid_
,
arg
.
p_lse_grid_
,
arg
.
p_d_grid_
,
arg
.
p_ygrad_grid_
,
...
...
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_light_v1.hpp
View file @
7c686fc2
...
...
@@ -1231,7 +1231,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
const
InputDataType
*
__restrict__
p_k_grid
,
ZDataType
*
__restrict__
p_z_grid
,
const
InputDataType
*
__restrict__
p_v_grid
,
const
InputDataType
*
__restrict__
p_y_grid
,
const
FloatLSE
*
__restrict__
p_lse_grid
,
const
FloatD
*
__restrict__
p_d_grid
,
const
InputDataType
*
__restrict__
p_ygrad_grid
,
...
...
@@ -1262,7 +1261,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
const
index_t
raw_n_padded
,
const
index_t
block_idx_n
)
{
ignore
=
p_y_grid
;
const
FloatGemmAcc
p_dropout
=
type_convert
<
FloatGemmAcc
>
(
1.0
f
-
p_drop
);
const
FloatGemmAcc
rp_dropout
=
type_convert
<
FloatGemmAcc
>
(
1.0
f
/
p_dropout
);
const
ushort
p_dropout_in_16bits
=
...
...
include/ck/tensor_operation/gpu/grid/gridwise_batched_mha_bwd_xdl_cshuffle_qloop_b2t_light_v2.hpp
View file @
7c686fc2
...
...
@@ -1163,7 +1163,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
const
InputDataType
*
__restrict__
p_k_grid
,
ZDataType
*
__restrict__
p_z_grid
,
const
InputDataType
*
__restrict__
p_v_grid
,
const
InputDataType
*
__restrict__
p_y_grid
,
const
FloatLSE
*
__restrict__
p_lse_grid
,
const
FloatD
*
__restrict__
p_d_grid
,
const
InputDataType
*
__restrict__
p_ygrad_grid
,
...
...
@@ -1194,7 +1193,6 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
const
index_t
raw_n_padded
,
const
index_t
block_idx_n
)
{
ignore
=
p_y_grid
;
const
FloatGemmAcc
p_dropout
=
type_convert
<
FloatGemmAcc
>
(
1.0
f
-
p_drop
);
const
FloatGemmAcc
rp_dropout
=
type_convert
<
FloatGemmAcc
>
(
1.0
f
/
p_dropout
);
const
ushort
p_dropout_in_16bits
=
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment