Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
0413755a
Commit
0413755a
authored
Jun 16, 2023
by
guangzlu
Browse files
modified offest to bwd pt6
parent
092a88bf
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
19 additions
and
21 deletions
+19
-21
include/ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_backward_xdl_cshuffle_v4.hpp
..._batched_multihead_attention_backward_xdl_cshuffle_v4.hpp
+9
-10
include/ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_backward_xdl_cshuffle_pt6.hpp
...batched_multihead_attention_backward_xdl_cshuffle_pt6.hpp
+10
-11
No files found.
include/ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_backward_xdl_cshuffle_v4.hpp
View file @
0413755a
...
@@ -86,8 +86,8 @@ __global__ void
...
@@ -86,8 +86,8 @@ __global__ void
const
float
p_drop
,
const
float
p_drop
,
const
unsigned
long
long
seed
,
const
unsigned
long
long
seed
,
const
unsigned
long
long
offset
,
const
unsigned
long
long
offset
,
const
index_t
MRaw
,
const
index_t
raw_m_padded
,
const
index_t
NRaw
)
const
index_t
raw_n_padded
)
{
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
__shared__
char
p_shared
[
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()];
__shared__
char
p_shared
[
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()];
...
@@ -110,10 +110,11 @@ __global__ void
...
@@ -110,10 +110,11 @@ __global__ void
const
long_index_t
lse_batch_offset
=
__builtin_amdgcn_readfirstlane
(
const
long_index_t
lse_batch_offset
=
__builtin_amdgcn_readfirstlane
(
static_cast
<
long_index_t
>
(
compute_base_ptr_of_batch
.
GetLSEBasePtr
(
g_idx
)));
static_cast
<
long_index_t
>
(
compute_base_ptr_of_batch
.
GetLSEBasePtr
(
g_idx
)));
const
index_t
global_thread_id
=
get_thread_global_1d_id
();
ck
::
philox
ph
(
seed
,
0
,
offset
);
ck
::
philox
ph
(
seed
,
global_thread_id
,
offset
);
ZDataType
*
z_matrix_ptr
=
(
p_z_grid
==
nullptr
?
nullptr
:
p_z_grid
+
z_batch_offset
);
ZDataType
*
z_matrix_ptr
=
(
p_z_grid
==
nullptr
?
nullptr
:
p_z_grid
+
z_batch_offset
);
const
index_t
z_random_matrix_offset
=
g_idx
*
raw_m_padded
*
raw_n_padded
;
if
constexpr
(
Deterministic
)
if
constexpr
(
Deterministic
)
{
{
for
(
index_t
i
=
0
;
i
<
nblock
;
i
++
)
for
(
index_t
i
=
0
;
i
<
nblock
;
i
++
)
...
@@ -146,9 +147,8 @@ __global__ void
...
@@ -146,9 +147,8 @@ __global__ void
c0_matrix_mask
,
c0_matrix_mask
,
p_drop
,
p_drop
,
ph
,
ph
,
g_idx
,
z_random_matrix_offset
,
MRaw
,
raw_n_padded
,
NRaw
,
i
);
i
);
}
}
}
}
...
@@ -181,9 +181,8 @@ __global__ void
...
@@ -181,9 +181,8 @@ __global__ void
c0_matrix_mask
,
c0_matrix_mask
,
p_drop
,
p_drop
,
ph
,
ph
,
g_idx
,
z_random_matrix_offset
,
MRaw
,
raw_n_padded
,
NRaw
,
0
);
0
);
}
}
#else
#else
...
...
include/ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_backward_xdl_cshuffle_pt6.hpp
View file @
0413755a
...
@@ -1254,9 +1254,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
...
@@ -1254,9 +1254,8 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
const
C0MatrixMask
&
c0_matrix_mask
,
const
C0MatrixMask
&
c0_matrix_mask
,
const
float
p_drop
,
const
float
p_drop
,
ck
::
philox
&
ph
,
ck
::
philox
&
ph
,
const
index_t
g_idx
,
const
index_t
z_random_matrix_offset
,
const
index_t
MRaw
,
const
index_t
raw_n_padded
,
const
index_t
NRaw
,
const
index_t
block_idx_n
)
const
index_t
block_idx_n
)
{
{
const
FloatGemmAcc
p_dropout
=
type_convert
<
FloatGemmAcc
>
(
1.0
f
-
p_drop
);
const
FloatGemmAcc
p_dropout
=
type_convert
<
FloatGemmAcc
>
(
1.0
f
-
p_drop
);
...
@@ -1959,16 +1958,16 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
...
@@ -1959,16 +1958,16 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
auto
m_global
=
m_local
+
m_block_data_idx_on_grid
;
auto
m_global
=
m_local
+
m_block_data_idx_on_grid
;
auto
n_global
=
n_local
+
n_block_data_idx_on_grid
;
auto
n_global
=
n_local
+
n_block_data_idx_on_grid
;
auto
global_elem_id_raw
=
auto
global_elem_id_raw
=
z_random_matrix_offset
+
m_global
*
raw_n_padded
+
MRaw
*
NRaw
*
g_idx
+
m_global
*
NRaw
+
n_global
;
// unique element global 1d id
n_global
;
// unique element global 1d id
auto
global_elem_id
=
auto
global_elem_id
=
(
global_elem_id_raw
%
M4
)
*
NRaw
+
int
(
global_elem_id_raw
/
M4
)
*
M4
;
(
global_elem_id_raw
%
M4
)
*
raw_n_padded
+
int
(
global_elem_id_raw
/
M4
)
*
M4
;
blockwise_dropout
.
template
ApplyDropoutAttnBwdSaveZ
<
decltype
(
s_slash_p_thread_buf
),
blockwise_dropout
.
template
ApplyDropoutAttnBwdSaveZ
<
decltype
(
s_slash_p_thread_buf
),
decltype
(
z_tenor_buffer
),
decltype
(
z_tenor_buffer
),
true
>(
true
>(
s_slash_p_thread_buf
,
ph
,
global_elem_id
,
z_tenor_buffer
,
NRaw
);
s_slash_p_thread_buf
,
ph
,
global_elem_id
,
z_tenor_buffer
,
raw_n_padded
);
z_thread_copy_vgpr_to_global
.
Run
(
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
,
z_thread_copy_vgpr_to_global
.
Run
(
z_thread_desc_m0_n0_m1_n1_m2_n2_m3_m4_m5_n3
,
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
),
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
,
I0
),
...
@@ -1986,16 +1985,16 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
...
@@ -1986,16 +1985,16 @@ struct GridwiseBatchedMultiheadAttentionBackward_Xdl_CShuffle_V1
auto
m_global
=
m_local
+
m_block_data_idx_on_grid
;
auto
m_global
=
m_local
+
m_block_data_idx_on_grid
;
auto
n_global
=
n_local
+
n_block_data_idx_on_grid
;
auto
n_global
=
n_local
+
n_block_data_idx_on_grid
;
auto
global_elem_id_raw
=
auto
global_elem_id_raw
=
z_random_matrix_offset
+
m_global
*
raw_n_padded
+
MRaw
*
NRaw
*
g_idx
+
m_global
*
NRaw
+
n_global
;
// unique element global 1d id
n_global
;
// unique element global 1d id
auto
global_elem_id
=
auto
global_elem_id
=
(
global_elem_id_raw
%
M4
)
*
NRaw
+
int
(
global_elem_id_raw
/
M4
)
*
M4
;
(
global_elem_id_raw
%
M4
)
*
raw_n_padded
+
int
(
global_elem_id_raw
/
M4
)
*
M4
;
// P_dropped
// P_dropped
blockwise_dropout
blockwise_dropout
.
template
ApplyDropoutAttnBwd
<
decltype
(
s_slash_p_thread_buf
),
true
>(
.
template
ApplyDropoutAttnBwd
<
decltype
(
s_slash_p_thread_buf
),
true
>(
s_slash_p_thread_buf
,
ph
,
global_elem_id
,
NRaw
);
s_slash_p_thread_buf
,
ph
,
global_elem_id
,
raw_n_padded
);
}
}
block_sync_lds
();
// wait for gemm1 LDS read
block_sync_lds
();
// wait for gemm1 LDS read
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment