Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
6d7a5784
"driver/src/driver.cpp" did not exist on "87d8740bf5d8030f9e4e54c9b7e64f353a6f944e"
Commit
6d7a5784
authored
Feb 15, 2023
by
ltqin
Browse files
add is_dropout parameter to gridwise
parent
63b37aa0
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
33 additions
and
19 deletions
+33
-19
include/ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_backward_xdl_cshuffle.hpp
...ice_batched_multihead_attention_backward_xdl_cshuffle.hpp
+9
-0
include/ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_backward_xdl_cshuffle_v1.hpp
..._batched_multihead_attention_backward_xdl_cshuffle_v1.hpp
+24
-19
No files found.
include/ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_backward_xdl_cshuffle.hpp
View file @
6d7a5784
...
@@ -83,6 +83,7 @@ __global__ void
...
@@ -83,6 +83,7 @@ __global__ void
const
ComputeBasePtrOfStridedBatch
compute_base_ptr_of_batch
,
const
ComputeBasePtrOfStridedBatch
compute_base_ptr_of_batch
,
const
C0MatrixMask
c0_matrix_mask
,
const
C0MatrixMask
c0_matrix_mask
,
const
float
p_dropout
,
const
float
p_dropout
,
const
bool
is_dropout
,
const
unsigned
long
long
seed
,
const
unsigned
long
long
seed
,
const
unsigned
long
long
offset
)
const
unsigned
long
long
offset
)
{
{
...
@@ -138,6 +139,7 @@ __global__ void
...
@@ -138,6 +139,7 @@ __global__ void
block_2_ctile_map
,
block_2_ctile_map
,
c0_matrix_mask
,
c0_matrix_mask
,
p_dropout
,
p_dropout
,
is_dropout
,
ph
);
ph
);
#else
#else
ignore
=
p_a_grid
;
ignore
=
p_a_grid
;
...
@@ -157,6 +159,10 @@ __global__ void
...
@@ -157,6 +159,10 @@ __global__ void
ignore
=
batch_count
;
ignore
=
batch_count
;
ignore
=
compute_base_ptr_of_batch
;
ignore
=
compute_base_ptr_of_batch
;
ignore
=
c0_matrix_mask
;
ignore
=
c0_matrix_mask
;
ignore
=
p_dropout
;
ignore
=
is_dropout
;
ignore
=
seed
;
ignore
=
offset
;
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
}
}
...
@@ -778,6 +784,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle
...
@@ -778,6 +784,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle
}
}
p_dropout_
=
1.
f
-
p_drop
;
p_dropout_
=
1.
f
-
p_drop
;
is_dropout_
=
p_drop
>
0.0
f
;
float
rp_dropout_
=
1.
f
/
p_dropout_
;
float
rp_dropout_
=
1.
f
/
p_dropout_
;
acc_element_op_
.
Append
(
rp_dropout_
);
acc_element_op_
.
Append
(
rp_dropout_
);
...
@@ -872,6 +879,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle
...
@@ -872,6 +879,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle
ComputeBasePtrOfStridedBatch
compute_base_ptr_of_batch_
;
ComputeBasePtrOfStridedBatch
compute_base_ptr_of_batch_
;
float
p_dropout_
;
float
p_dropout_
;
bool
is_dropout_
;
unsigned
long
long
seed_
;
unsigned
long
long
seed_
;
unsigned
long
long
offset_
;
unsigned
long
long
offset_
;
};
};
...
@@ -954,6 +962,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle
...
@@ -954,6 +962,7 @@ struct DeviceBatchedMultiheadAttentionBackward_Xdl_CShuffle
arg
.
compute_base_ptr_of_batch_
,
arg
.
compute_base_ptr_of_batch_
,
arg
.
c0_matrix_mask_
,
arg
.
c0_matrix_mask_
,
arg
.
p_dropout_
,
arg
.
p_dropout_
,
arg
.
is_dropout_
,
arg
.
seed_
,
arg
.
seed_
,
arg
.
offset_
);
arg
.
offset_
);
};
};
...
...
include/ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_backward_xdl_cshuffle_v1.hpp
View file @
6d7a5784
...
@@ -1170,6 +1170,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
...
@@ -1170,6 +1170,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
const
Block2CTileMap
&
block_2_ctile_map
,
const
Block2CTileMap
&
block_2_ctile_map
,
const
C0MatrixMask
&
c0_matrix_mask
,
const
C0MatrixMask
&
c0_matrix_mask
,
FloatGemmAcc
p_dropout
,
FloatGemmAcc
p_dropout
,
const
bool
is_dropout
,
ck
::
philox
&
ph
)
ck
::
philox
&
ph
)
{
{
const
ushort
p_dropout_in_16bits
=
uint16_t
(
std
::
floor
(
p_dropout
*
65535.0
));
const
ushort
p_dropout_in_16bits
=
uint16_t
(
std
::
floor
(
p_dropout
*
65535.0
));
...
@@ -1847,6 +1848,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
...
@@ -1847,6 +1848,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
blockwise_softmax
.
RunWithPreCalcStats
(
s_slash_p_thread_buf
,
lse_thread_buf
);
blockwise_softmax
.
RunWithPreCalcStats
(
s_slash_p_thread_buf
,
lse_thread_buf
);
// save z to global
// save z to global
if
(
is_dropout
)
{
if
(
p_z_grid
)
if
(
p_z_grid
)
{
{
// P_dropped
// P_dropped
...
@@ -1855,7 +1858,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
...
@@ -1855,7 +1858,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
true
>(
true
>(
s_slash_p_thread_buf
,
ph
,
z_tenor_buffer
);
s_slash_p_thread_buf
,
ph
,
z_tenor_buffer
);
z_thread_copy_vgpr_to_global
.
Run
(
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
z_thread_copy_vgpr_to_global
.
Run
(
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
),
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
),
z_tenor_buffer
,
z_tenor_buffer
,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5
,
...
@@ -1867,6 +1871,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
...
@@ -1867,6 +1871,7 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V2
blockwise_dropout
.
template
ApplyDropout
<
decltype
(
s_slash_p_thread_buf
),
true
>(
blockwise_dropout
.
template
ApplyDropout
<
decltype
(
s_slash_p_thread_buf
),
true
>(
s_slash_p_thread_buf
,
ph
);
s_slash_p_thread_buf
,
ph
);
}
}
}
block_sync_lds
();
// wait for gemm1 LDS read
block_sync_lds
();
// wait for gemm1 LDS read
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment