Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
TransformerEngine
Commits
0b0a70a5
Commit
0b0a70a5
authored
Apr 25, 2025
by
yuguo
Browse files
Merge branch 'main' of
http://10.6.10.68/dcutoolkit/deeplearing/TransformerEngine
parents
e80f260d
3ce226ae
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
18 additions
and
3 deletions
+18
-3
transformer_engine/common/transpose/cast_transpose_fusion.cu
transformer_engine/common/transpose/cast_transpose_fusion.cu
+2
-1
transformer_engine/common/transpose/rtc/cast_transpose.cu
transformer_engine/common/transpose/rtc/cast_transpose.cu
+14
-1
transformer_engine/common/transpose/rtc/cast_transpose_fusion.cu
...rmer_engine/common/transpose/rtc/cast_transpose_fusion.cu
+2
-1
No files found.
transformer_engine/common/transpose/cast_transpose_fusion.cu
View file @
0b0a70a5
...
...
@@ -171,7 +171,8 @@ inline __device__ void cast_and_transpose_regs(const CVec (&in)[nvec_out],
     for (unsigned int j = 0; j < nvec_in; ++j) {
       CType elt = step_dbias.data.elt[j];
 #ifdef __HIP_PLATFORM_AMD__
-      elt = __shfl(elt, dbias_shfl_src_lane);  // shuffle data in a warp
+      elt = __shfl(elt, dbias_shfl_src_lane, THREADS_PER_WARP);  // shuffle data in a warp
+      __syncthreads();
 #else
       elt = __shfl_sync(0xffffffff, elt, dbias_shfl_src_lane);  // shuffle data in a warp
 #endif
...
...
transformer_engine/common/transpose/rtc/cast_transpose.cu
View file @
0b0a70a5
...
...
@@ -91,16 +91,25 @@ __global__ void __launch_bounds__(block_size) cast_transpose_optimized_kernel(
         local_output_c.store_to(&output_c[row * row_length + col]);
       }
     }
+#ifndef __HIP_PLATFORM_AMD__
     // Copy from registers to shared memory to global memory
     __shared__ OVecT shared_output_t[THREADS_PER_WARP][THREADS_PER_WARP + 1];
+#else
+    constexpr size_t inner_dim = THREADS_PER_WARP + 1;
+    constexpr size_t outter_dim = THREADS_PER_WARP;
+    __shared__ OVecT shared_output_t[outter_dim * inner_dim];
+#endif
 #pragma unroll
     for (size_t j2 = 0; j2 < nvec_in; ++j2) {
 #pragma unroll
       for (size_t iter = 0; iter < num_iterations; ++iter) {
         const size_t i1 = tidy + iter * bdimy;
         const size_t j1 = tidx;
+#ifndef __HIP_PLATFORM_AMD__
         shared_output_t[j1][i1] = local_output_t[j2][iter];
+#else
+        shared_output_t[j1 * inner_dim + i1] = local_output_t[j2][iter];
+#endif
       }
       __syncthreads();
 #pragma unroll
...
...
@@ -109,7 +118,11 @@ __global__ void __launch_bounds__(block_size) cast_transpose_optimized_kernel(
         const size_t j1 = tidy + iter * bdimy;
         const size_t row = tile_row + i1 * nvec_out;
         const size_t col = tile_col + j1 * nvec_in + j2;
+#ifndef __HIP_PLATFORM_AMD__
         shared_output_t[j1][i1].store_to(&output_t[col * num_rows + row]);
+#else
+        shared_output_t[j1 * inner_dim + i1].store_to(&output_t[col * num_rows + row]);
+#endif
       }
       __syncthreads();
     }
...
...
transformer_engine/common/transpose/rtc/cast_transpose_fusion.cu
View file @
0b0a70a5
...
...
@@ -91,7 +91,8 @@ inline __device__ void cast_and_transpose_regs_optimized(const CVec (&in)[NVEC_O
     for (unsigned int j = 0; j < NVEC_IN; ++j) {
       CType elt = step_dbias.data.elt[j];
 #ifdef __HIP_PLATFORM_AMD__
-      elt = __shfl(elt, dbias_shfl_src_lane);  // shuffle data in a warp
+      elt = __shfl(elt, dbias_shfl_src_lane, THREADS_PER_WARP);  // shuffle data in a warp
+      __syncthreads();
 #else
       elt = __shfl_sync(0xffffffff, elt, dbias_shfl_src_lane);  // shuffle data in a warp
 #endif
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment