OpenDAS / TransformerEngine · Commits

Commit b65e50ba, authored May 21, 2025 by yuguo
Parent: f8c2af4c

[DCU] fix merge
Showing 4 changed files with 37 additions and 23 deletions (+37 -23):
  transformer_engine/common/fused_attn/context_parallel.cu                 +1  -1
  transformer_engine/common/multi_tensor/l2norm.cu                          +2  -2
  transformer_engine/common/recipe/fp8_block_scaling.cu                     +12 -0
  transformer_engine/pytorch/attention/dot_product_attention/backends.py    +22 -20
transformer_engine/common/fused_attn/context_parallel.cu

@@ -189,7 +189,7 @@ __global__ void thd_lse_kernel(float *lse, float *half_lse, int *cu_seqlens, int
  **************************************************************************************************/

 template <typename dtype, int only_second_half, int tile_size, bool lse_packed>
-__global__ void thd_out_correction_kernel(dtype *out, dtype *out_per_step, float *lse,
+__global__ void __launch_bounds__(512) thd_out_correction_kernel(dtype *out, dtype *out_per_step, float *lse,
                                           float *lse_per_step, int *cu_seqlens, int batch,
                                           int num_heads, int dim_per_head, int lse_seqlen,
                                           int lse_per_step_seqlen) {
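The single-line change above attaches __launch_bounds__(512) to thd_out_correction_kernel, which promises the compiler that the kernel is never launched with more than 512 threads per block so it can budget registers for that upper bound. As a hedged illustration only (the kernel below and its launch configuration are hypothetical, not taken from this repository), the attribute is used like this:

#include <cuda_runtime.h>
#include <cstdio>

// Hypothetical kernel: __launch_bounds__(512) promises the compiler that
// blockDim.x * blockDim.y * blockDim.z <= 512, so it can allocate registers
// assuming at most 512 resident threads per block.
__global__ void __launch_bounds__(512) scale_kernel(float *out, const float *in, float alpha, int n) {
  int idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (idx < n) out[idx] = alpha * in[idx];
}

int main() {
  const int n = 1 << 20;
  float *in, *out;
  cudaMalloc(&in, n * sizeof(float));
  cudaMalloc(&out, n * sizeof(float));
  cudaMemset(in, 0, n * sizeof(float));
  // Launching with more than 512 threads per block would now fail at launch time.
  scale_kernel<<<(n + 511) / 512, 512>>>(out, in, 2.0f, n);
  cudaDeviceSynchronize();
  printf("launch status: %s\n", cudaGetErrorString(cudaGetLastError()));
  cudaFree(in);
  cudaFree(out);
  return 0;
}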
transformer_engine/common/multi_tensor/l2norm.cu

@@ -57,7 +57,7 @@ reduce_block_into_lanes(T *x, T val, int lanes = 1,
     // __SYNCWARP();
 #ifdef __HIP_PLATFORM_AMD__
 #pragma unroll
-    for (int i = 16; i >= lanes; i >>= 1) final = final + __shfl_down(final, i);
+    for (int i = 16; i >= lanes; i >>= 1) final = final + __shfl_down(final, i, THREADS_PER_WARP);
 #else
 #pragma unroll
     for (int i = 16; i >= lanes; i >>= 1) final = final + __shfl_down_sync(0xffffffff, final, i);

@@ -104,7 +104,7 @@ reduce_block_into_lanes_max_op(T *x, T val, int lanes = 1,
 #pragma unroll
     for (int i = 16; i >= lanes; i >>= 1)
 #ifdef __HIP_PLATFORM_AMD__
-      final = fmaxf(fabsf(final), fabsf(__shfl_down(final, i)));
+      final = fmaxf(fabsf(final), fabsf(__shfl_down(final, i, THREADS_PER_WARP)));
 #else
       final = fmaxf(fabsf(final), fabsf(__shfl_down_sync(0xffffffff, final, i)));
 #endif
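Both hunks in this file change the HIP/DCU branch to pass the shuffle width explicitly, __shfl_down(final, i, THREADS_PER_WARP), instead of relying on the intrinsic's default width. A minimal sketch of the same #ifdef split wrapped into a warp-sum helper is below; the helper name warp_reduce_sum, the fallback THREADS_PER_WARP values, and the host driver are assumptions for illustration, not code from the repository:

#include <cuda_runtime.h>
#include <cstdio>

#ifndef THREADS_PER_WARP
// Assumption for this sketch only: 64 lanes per AMD wavefront, 32 per NVIDIA warp.
#ifdef __HIP_PLATFORM_AMD__
#define THREADS_PER_WARP 64
#else
#define THREADS_PER_WARP 32
#endif
#endif

// Illustrative helper (not from the repository): sum `val` across the lanes of
// one warp, using the same #ifdef split as the patched loops in l2norm.cu.
__device__ float warp_reduce_sum(float val) {
#pragma unroll
  for (int offset = THREADS_PER_WARP / 2; offset > 0; offset >>= 1) {
#ifdef __HIP_PLATFORM_AMD__
    val += __shfl_down(val, offset, THREADS_PER_WARP);   // explicit width, as in the commit
#else
    val += __shfl_down_sync(0xffffffff, val, offset);    // CUDA path keeps the *_sync intrinsic
#endif
  }
  return val;  // lane 0 now holds the warp-wide sum
}

__global__ void sum_kernel(const float *in, float *out) {
  float v = warp_reduce_sum(in[threadIdx.x]);
  if (threadIdx.x == 0) *out = v;
}

int main() {
  float host_in[THREADS_PER_WARP], host_out = 0.0f;
  for (int i = 0; i < THREADS_PER_WARP; ++i) host_in[i] = 1.0f;
  float *dev_in, *dev_out;
  cudaMalloc(&dev_in, sizeof(host_in));
  cudaMalloc(&dev_out, sizeof(float));
  cudaMemcpy(dev_in, host_in, sizeof(host_in), cudaMemcpyHostToDevice);
  sum_kernel<<<1, THREADS_PER_WARP>>>(dev_in, dev_out);
  cudaMemcpy(&host_out, dev_out, sizeof(float), cudaMemcpyDeviceToHost);
  printf("warp sum = %f (expected %d)\n", host_out, THREADS_PER_WARP);
  cudaFree(dev_in);
  cudaFree(dev_out);
  return 0;
}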
transformer_engine/common/recipe/fp8_block_scaling.cu

@@ -52,7 +52,11 @@ __global__ void __launch_bounds__(kThreadsPerBlock)
   }
   for (int delta = kThreadsPerWarp / 2; delta > 0; delta /= 2) {
+#ifdef __HIP_PLATFORM_AMD__
+    float other_amax = __shfl_down(amax, delta);
+#else
     float other_amax = __shfl_down_sync(0xFFFFFFFF, amax, delta);
+#endif
     __builtin_assume(amax >= 0);
     __builtin_assume(other_amax >= 0);
     amax = fmaxf(amax, other_amax);

@@ -119,10 +123,18 @@ __global__ void __launch_bounds__(kThreadsPerBlock)
   }
   for (int delta = kThreadsPerWarp / 2; delta > 0; delta /= 2) {
+#ifdef __HIP_PLATFORM_AMD__
+    bool other_skip_store = __shfl_down(skip_store, delta);
+#else
     bool other_skip_store = __shfl_down_sync(0xFFFFFFFF, skip_store, delta);
+#endif
     skip_store = skip_store && other_skip_store;
   }
+#ifdef __HIP_PLATFORM_AMD__
+  skip_store = __shfl(skip_store, 0);
+#else
   skip_store = __shfl_sync(0xFFFFFFFF, skip_store, 0);
+#endif
   if (skip_store) {
     return;
   }
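All additions in this file guard the *_sync warp intrinsics used on the CUDA path (__shfl_down_sync, __shfl_sync) behind #else and give the HIP/DCU path the non-_sync variants (__shfl_down, __shfl). A small sketch of that pattern applied to a warp-level amax reduction with a lane-0 broadcast follows; the kernel, the kThreadsPerWarp value of 32, and the host driver are illustrative assumptions, not repository code:

#include <cuda_runtime.h>
#include <cstdio>

// Assumption for this sketch only: the warp width used by the reduction.
constexpr int kThreadsPerWarp = 32;

// Illustrative kernel (not from the repository): reduce each thread's |value|
// to a warp-wide amax, then broadcast lane 0's result, using the same
// #ifdef split the commit adds in fp8_block_scaling.cu.
__global__ void warp_amax_kernel(const float *in, float *out) {
  float amax = fabsf(in[threadIdx.x]);
  for (int delta = kThreadsPerWarp / 2; delta > 0; delta /= 2) {
#ifdef __HIP_PLATFORM_AMD__
    float other_amax = __shfl_down(amax, delta);                   // HIP/DCU path
#else
    float other_amax = __shfl_down_sync(0xFFFFFFFF, amax, delta);  // CUDA path
#endif
    amax = fmaxf(amax, other_amax);
  }
  // Lane 0 now holds the warp-wide amax; broadcast it to every lane.
#ifdef __HIP_PLATFORM_AMD__
  amax = __shfl(amax, 0);
#else
  amax = __shfl_sync(0xFFFFFFFF, amax, 0);
#endif
  out[threadIdx.x] = amax;
}

int main() {
  float host_in[kThreadsPerWarp], host_out[kThreadsPerWarp];
  for (int i = 0; i < kThreadsPerWarp; ++i) host_in[i] = (i == 7) ? -3.5f : 1.0f;
  float *dev_in, *dev_out;
  cudaMalloc(&dev_in, sizeof(host_in));
  cudaMalloc(&dev_out, sizeof(host_out));
  cudaMemcpy(dev_in, host_in, sizeof(host_in), cudaMemcpyHostToDevice);
  warp_amax_kernel<<<1, kThreadsPerWarp>>>(dev_in, dev_out);
  cudaMemcpy(host_out, dev_out, sizeof(host_out), cudaMemcpyDeviceToHost);
  printf("warp amax = %f (expected 3.5)\n", host_out[0]);
  cudaFree(dev_in);
  cudaFree(dev_out);
  return 0;
}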
transformer_engine/pytorch/attention/dot_product_attention/backends.py

@@ -13,6 +13,7 @@ import logging
 from packaging.version import Version as PkgVersion
 import torch
+from torch.utils.cpp_extension import IS_HIP_EXTENSION
 import transformer_engine_torch as tex
 from transformer_engine.pytorch.utils import (
     SplitAlongDim,

@@ -92,7 +93,7 @@ else:
     fa_utils.set_flash_attention_version()
 elif (
     torch.cuda.is_available()
-    and get_device_compute_capability() >= (8, 0)
+    and (IS_HIP_EXTENSION or get_device_compute_capability() >= (8, 0))
     and dpa_utils._NVTE_FLASH_ATTN
 ):
     attn_log.fa_logger.warning(

@@ -107,14 +108,15 @@ else:
         ),
         fa_utils.version,
     )
-try:
-    fa_utils.fa3_version = PkgVersion(get_pkg_version("flash-attn-3"))
-except PackageNotFoundError:
-    flash_attn_func_v3 = None
-    flash_attn_varlen_func_v3 = None
-    flash_attn_with_kvcache_v3 = None
-    # pass # only print warning if use_flash_attention_3 = True in get_attention_backend
-else:
-    from flash_attn_3.flash_attn_interface import flash_attn_func as flash_attn_func_v3
-    from flash_attn_3.flash_attn_interface import (
-        flash_attn_varlen_func as flash_attn_varlen_func_v3,
+if not IS_HIP_EXTENSION:
+    try:
+        fa_utils.fa3_version = PkgVersion(get_pkg_version("flash-attn-3"))
+    except PackageNotFoundError:
+        flash_attn_func_v3 = None
+        flash_attn_varlen_func_v3 = None
+        flash_attn_with_kvcache_v3 = None
+        # pass # only print warning if use_flash_attention_3 = True in get_attention_backend
+    else:
+        from flash_attn_3.flash_attn_interface import flash_attn_func as flash_attn_func_v3
+        from flash_attn_3.flash_attn_interface import (
+            flash_attn_varlen_func as flash_attn_varlen_func_v3,