Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
092a88bf
Commit
092a88bf
authored
Jun 16, 2023
by
guangzlu
Browse files
modified z offset into fwd
parent
adb8aaaa
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
14 additions
and
15 deletions
+14
-15
include/ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_forward_xdl_cshuffle_v2.hpp
...e_batched_multihead_attention_forward_xdl_cshuffle_v2.hpp
+10
-10
include/ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_forward_xdl_cshuffle_pt2.hpp
..._batched_multihead_attention_forward_xdl_cshuffle_pt2.hpp
+4
-5
No files found.
include/ck/tensor_operation/gpu/device/impl/device_batched_multihead_attention_forward_xdl_cshuffle_v2.hpp
View file @
092a88bf
...
@@ -80,8 +80,8 @@ __global__ void
...
@@ -80,8 +80,8 @@ __global__ void
const
GemmAccDataType
p_dropout_rescale
,
const
GemmAccDataType
p_dropout_rescale
,
const
unsigned
long
long
seed
,
const
unsigned
long
long
seed
,
const
unsigned
long
long
offset
,
const
unsigned
long
long
offset
,
const
index_t
MRaw
,
const
index_t
raw_m_padded
,
const
index_t
NRaw
)
const
index_t
raw_n_padded
)
{
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
__shared__
char
p_shared
[
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()];
__shared__
char
p_shared
[
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()];
...
@@ -102,8 +102,10 @@ __global__ void
...
@@ -102,8 +102,10 @@ __global__ void
const
long_index_t
lse_batch_offset
=
__builtin_amdgcn_readfirstlane
(
const
long_index_t
lse_batch_offset
=
__builtin_amdgcn_readfirstlane
(
static_cast
<
long_index_t
>
(
compute_base_ptr_of_batch
.
GetLSEBasePtr
(
g_idx
)));
static_cast
<
long_index_t
>
(
compute_base_ptr_of_batch
.
GetLSEBasePtr
(
g_idx
)));
const
index_t
global_thread_id
=
get_thread_global_1d_id
();
// const index_t global_thread_id = get_thread_global_1d_id();
ck
::
philox
ph
(
seed
,
global_thread_id
,
offset
);
ck
::
philox
ph
(
seed
,
0
,
offset
);
const
index_t
z_random_matrix_offset
=
g_idx
*
raw_m_padded
*
raw_n_padded
;
if
constexpr
(
Deterministic
)
if
constexpr
(
Deterministic
)
{
{
...
@@ -133,9 +135,8 @@ __global__ void
...
@@ -133,9 +135,8 @@ __global__ void
p_dropout_in_16bits
,
p_dropout_in_16bits
,
p_dropout_rescale
,
p_dropout_rescale
,
ph
,
ph
,
g_idx
,
z_random_matrix_offset
,
MRaw
,
raw_n_padded
,
NRaw
,
i
);
i
);
}
}
}
}
...
@@ -165,9 +166,8 @@ __global__ void
...
@@ -165,9 +166,8 @@ __global__ void
p_dropout_in_16bits
,
p_dropout_in_16bits
,
p_dropout_rescale
,
p_dropout_rescale
,
ph
,
ph
,
g_idx
,
z_random_matrix_offset
,
MRaw
,
raw_n_padded
,
NRaw
,
0
);
0
);
}
}
#else
#else
...
...
include/ck/tensor_operation/gpu/grid/gridwise_batched_multihead_attention_forward_xdl_cshuffle_pt2.hpp
View file @
092a88bf
...
@@ -463,9 +463,8 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
...
@@ -463,9 +463,8 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
const
ushort
p_dropout_in_16bits
,
const
ushort
p_dropout_in_16bits
,
FloatGemmAcc
p_dropout_rescale
,
FloatGemmAcc
p_dropout_rescale
,
ck
::
philox
&
ph
,
ck
::
philox
&
ph
,
const
index_t
g_idx
,
const
index_t
z_random_matrix_offset
,
const
index_t
MRaw
,
const
index_t
raw_n_padded
,
const
index_t
NRaw
,
const
index_t
block_idx_m
)
const
index_t
block_idx_m
)
{
{
const
auto
a_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
const
auto
a_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
...
@@ -1197,8 +1196,8 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
...
@@ -1197,8 +1196,8 @@ struct GridwiseBatchedMultiheadAttentionForward_Xdl_CShuffle
auto
m_global
=
m_local
+
m_block_data_idx_on_grid
;
auto
m_global
=
m_local
+
m_block_data_idx_on_grid
;
auto
n_global
=
n_local
+
n_block_data_idx_on_grid
;
auto
n_global
=
n_local
+
n_block_data_idx_on_grid
;
auto
global_elem_id
=
auto
global_elem_id
=
z_random_matrix_offset
+
m_global
*
raw_n_padded
+
MRaw
*
NRaw
*
g_idx
+
m_global
*
NRaw
+
n_global
;
// unique element global 1d id
n_global
;
// unique element global 1d id
// index_t id_step = Acc0TileIterator::GetNumOfAccess() / n0.value;
// index_t id_step = Acc0TileIterator::GetNumOfAccess() / n0.value;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment