Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
1d7099b6
Commit
1d7099b6
authored
Jul 19, 2024
by
danyao12
Browse files
fix hd256 dropout scratch
parent
a67bdd63
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
39 additions
and
7 deletions
+39
-7
include/ck_tile/core/numeric/vector_type.hpp
include/ck_tile/core/numeric/vector_type.hpp
+9
-0
include/ck_tile/core/utility/philox_rand.hpp
include/ck_tile/core/utility/philox_rand.hpp
+16
-0
include/ck_tile/ops/fmha/block/block_dropout.hpp
include/ck_tile/ops/fmha/block/block_dropout.hpp
+14
-7
No files found.
include/ck_tile/core/numeric/vector_type.hpp
View file @
1d7099b6
...
@@ -117,6 +117,15 @@ using int32x16_t = int32_t __attribute__((ext_vector_type(16)));
...
@@ -117,6 +117,15 @@ using int32x16_t = int32_t __attribute__((ext_vector_type(16)));
using
int32x32_t
=
int32_t
__attribute__
((
ext_vector_type
(
32
)));
using
int32x32_t
=
int32_t
__attribute__
((
ext_vector_type
(
32
)));
using
int32x64_t
=
int32_t
__attribute__
((
ext_vector_type
(
64
)));
using
int32x64_t
=
int32_t
__attribute__
((
ext_vector_type
(
64
)));
// u32
// using uint32_t = ...
using
uint32x2_t
=
uint32_t
__attribute__
((
ext_vector_type
(
2
)));
using
uint32x4_t
=
uint32_t
__attribute__
((
ext_vector_type
(
4
)));
using
uint32x8_t
=
uint32_t
__attribute__
((
ext_vector_type
(
8
)));
using
uint32x16_t
=
uint32_t
__attribute__
((
ext_vector_type
(
16
)));
using
uint32x32_t
=
uint32_t
__attribute__
((
ext_vector_type
(
32
)));
using
uint32x64_t
=
uint32_t
__attribute__
((
ext_vector_type
(
64
)));
// i16
// i16
// using int16_t = ...
// using int16_t = ...
using
int16x2_t
=
int16_t
__attribute__
((
ext_vector_type
(
2
)));
using
int16x2_t
=
int16_t
__attribute__
((
ext_vector_type
(
2
)));
...
...
include/ck_tile/core/utility/philox_rand.hpp
View file @
1d7099b6
...
@@ -53,6 +53,22 @@ class philox
...
@@ -53,6 +53,22 @@ class philox
out_tmp
[
3
]
=
tmp_ph
.
w
;
out_tmp
[
3
]
=
tmp_ph
.
w
;
}
}
CK_TILE_HOST_DEVICE
void
get_random_4x8
(
uint8_t
*
out
,
const
unsigned
long
long
subsequence
,
const
index_t
start_idx
)
const
{
uint4
tmp_ph
;
tmp_ph
=
get_philox_4x32
(
subsequence
);
uint32x4_t
tmp
;
tmp
[
0
]
=
tmp_ph
.
x
;
tmp
[
1
]
=
tmp_ph
.
y
;
tmp
[
2
]
=
tmp_ph
.
z
;
tmp
[
3
]
=
tmp_ph
.
w
;
uint32_t
*
out_tmp
=
reinterpret_cast
<
uint32_t
*>
(
&
out
[
0
]);
out_tmp
[
0
]
=
tmp
[
start_idx
];
}
private:
private:
struct
ull2
struct
ull2
{
{
...
...
include/ck_tile/ops/fmha/block/block_dropout.hpp
View file @
1d7099b6
...
@@ -333,19 +333,26 @@ struct BlockDropout
...
@@ -333,19 +333,26 @@ struct BlockDropout
uint2
rowcol
=
make_uint2
(
block_row_start
,
block_col_start
);
uint2
rowcol
=
make_uint2
(
block_row_start
,
block_col_start
);
// generate random number
// generate random number
uint8_t
random_uint8_t
[
16
];
uint8_t
*
random_uint8_t_
;
uint8_t
*
random_uint8_t_
=
random_uint8_t
;
ph
.
get_random_16x8
(
random_uint8_t
,
reinterpret_cast
<
unsigned
long
long
&>
(
rowcol
));
if
constexpr
(
!
IsWG32
)
if
constexpr
(
!
IsWG32
)
{
{
uint8_t
random_uint8_t
[
4
];
// m0t0 ~m0t15/m0t32~m0t47: 0
// m0t0 ~m0t15/m0t32~m0t47: 0
// m0t16~m0t31/m0t48~m0t63: 1
// m0t16~m0t31/m0t48~m0t63: 1
// m1t0 ~m1t15/m1t32~m1t47: 2
// m1t0 ~m1t15/m1t32~m1t47: 2
// m1t16~m1t31/m1t48~m1t63: 3
// m1t16~m1t31/m1t48~m1t63: 3
int
start_idx
=
((
get_lane_id
()
>>
4
)
&
1
)
+
(((
start_m0_idx
>>
4
)
&
1
)
<<
1
);
const
index_t
start_idx
=
uint32_t
*
random_uint32_t
=
reinterpret_cast
<
uint32_t
*>
(
random_uint8_t
);
((
get_lane_id
()
>>
4
)
&
1
)
+
(((
start_m0_idx
>>
4
)
&
1
)
<<
1
);
random_uint8_t_
=
reinterpret_cast
<
uint8_t
*>
(
&
random_uint32_t
[
start_idx
]);
ph
.
get_random_4x8
(
random_uint8_t
,
reinterpret_cast
<
unsigned
long
long
&>
(
rowcol
),
start_idx
);
random_uint8_t_
=
random_uint8_t
;
}
else
{
uint8_t
random_uint8_t
[
16
];
ph
.
get_random_16x8
(
random_uint8_t
,
reinterpret_cast
<
unsigned
long
long
&>
(
rowcol
));
random_uint8_t_
=
random_uint8_t
;
}
}
constexpr
auto
randval_spans
=
decltype
(
randval
)
::
get_distributed_spans
();
constexpr
auto
randval_spans
=
decltype
(
randval
)
::
get_distributed_spans
();
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment