OpenDAS / TransformerEngine · Commits

Commit 2b05e121, authored Jun 17, 2025 by yuguo

Merge commit 'a69692ac' of https://github.com/NVIDIA/TransformerEngine

Parents: 0fd441c2, a69692ac
Changes: 245
Showing 5 changed files with 152 additions and 73 deletions (+152, −73)
transformer_engine/pytorch/tensor/float8_tensor.py (+9, −2)
transformer_engine/pytorch/tensor/mxfp8_tensor.py (+7, −2)
transformer_engine/pytorch/tensor/quantized_tensor.py (+16, −0)
transformer_engine/pytorch/transformer.py (+111, −69)
transformer_engine/pytorch/utils.py (+9, −0)
transformer_engine/pytorch/tensor/float8_tensor.py (+9, −2)
@@ -4,13 +4,14 @@
 """Tensor class with FP8 data"""
 from __future__ import annotations
-from typing import Optional, Tuple, Iterable
+from typing import Optional, Tuple, Iterable, Union
 import warnings

 import torch
 import transformer_engine_torch as tex
 from transformer_engine_torch import DType as TE_DType
+from transformer_engine.common.recipe import DelayedScaling, Float8CurrentScaling, Recipe
 from ..utils import canonicalize_process_group, devices_match
 from ._internal.float8_tensor_base import Float8TensorBase, _FromFloat8Func
 from .quantized_tensor import QuantizedTensor, Quantizer, _IdentityFunc
@@ -166,6 +167,9 @@ class Float8Quantizer(Quantizer):
             quantizer=self,
         )

+    def _get_compatible_recipe(self) -> Union[type[Recipe], None]:
+        return DelayedScaling
+

 class Float8CurrentScalingQuantizer(Quantizer):
     """Builder class for FP8 tensors with per-tensor current scaling
@@ -328,6 +332,9 @@ class Float8CurrentScalingQuantizer(Quantizer):
         """Get process group for amax reduction"""
         return canonicalize_process_group(self.amax_reduction_group)

+    def _get_compatible_recipe(self) -> Union[type[Recipe], None]:
+        return Float8CurrentScaling
+

 class Float8Tensor(Float8TensorBase, QuantizedTensor):
     """Experimental tensor class with FP8 data
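For context on the two recipes wired up here: with DelayedScaling the FP8 scale comes from an amax history, while Float8CurrentScaling derives it from the amax of the tensor currently being quantized. A minimal, illustrative sketch of the current-scaling idea only (not the library's implementation; the constant and function name below are hypothetical):

    import torch

    FP8_E4M3_MAX = 448.0  # largest finite magnitude representable in the e4m3 format

    def current_scaling_scale(tensor: torch.Tensor) -> torch.Tensor:
        # Per-tensor current scaling: derive the scale from this tensor's own amax,
        # instead of an amax history as in delayed scaling.
        amax = tensor.abs().amax().float()
        return FP8_E4M3_MAX / torch.clamp(amax, min=torch.finfo(torch.float32).tiny)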
transformer_engine/pytorch/tensor/mxfp8_tensor.py (+7, −2)
@@ -6,12 +6,13 @@
 from __future__ import annotations
 from collections.abc import Iterable
 import math
-from typing import Optional, Tuple
+from typing import Optional, Tuple, Union

 import torch
 import transformer_engine_torch as tex
 from transformer_engine_torch import DType as TE_DType
+from transformer_engine.common.recipe import MXFP8BlockScaling, Recipe
 from ..constants import MXFP8_BLOCK_SCALING_SIZE
 from ..utils import devices_match, round_up_to_nearest_multiple
@@ -135,6 +136,9 @@ class MXFP8Quantizer(Quantizer):
         # TODO(ksivamani): No calibration needed for mxfp8?
         pass

+    def _get_compatible_recipe(self) -> Union[type[Recipe], None]:
+        return MXFP8BlockScaling
+

 class MXFP8Tensor(MXFP8TensorBase, QuantizedTensor):
     """Experimental tensor class with FP8 data
@@ -380,6 +384,7 @@ class MXFP8Tensor(MXFP8TensorBase, QuantizedTensor):
         # Quantize to FP8
         assert self._quantizer is not None, "Can't quantize without a quantizer"
+        self._quantizer.internal = False
         self.data = self._quantizer.quantize(tensor)
         if self.requires_grad != tensor.requires_grad:
             self.requires_grad_(requires_grad=tensor.requires_grad)
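Background for the MXFP8 quantizer changes: MXFP8 block scaling stores one scale per contiguous block of elements, with the block length taken from MXFP8_BLOCK_SCALING_SIZE (32 in the MX formats; treated as an assumption here). A rough sketch of per-block amax computation, ignoring the padding handled by round_up_to_nearest_multiple:

    import torch

    MXFP8_BLOCK_SCALING_SIZE = 32  # assumed value; the real constant lives in ..constants

    def per_block_amax(x: torch.Tensor) -> torch.Tensor:
        # One amax (and hence one scale) per block of consecutive elements along the
        # last dimension; the length is assumed to already be a multiple of the block size.
        assert x.shape[-1] % MXFP8_BLOCK_SCALING_SIZE == 0
        blocks = x.reshape(*x.shape[:-1], -1, MXFP8_BLOCK_SCALING_SIZE)
        return blocks.abs().amax(dim=-1)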
transformer_engine/pytorch/tensor/quantized_tensor.py (+16, −0)
@@ -8,11 +8,13 @@ from __future__ import annotations
 from typing import Optional, Tuple, Iterable, Any, Dict, Union
 import abc
 import copy
+import warnings

 import torch
 from torch.utils._pytree import tree_map

 import transformer_engine_torch as tex
+from transformer_engine.common.recipe import Recipe


 class QuantizedTensorBase:
@@ -31,6 +33,8 @@ class QuantizedTensorBase:
     XTensor should only implement the functionality needed
     to behave like regular torch.Tensor (liek __torch_dispatch__)."""

+    _quantizer: Optional[Quantizer]
+
     def update_usage(
         self,
         rowwise_usage: Optional[bool] = None,
@@ -69,6 +73,14 @@ class QuantizedTensorBase:
                 f"{self.__class__.__name__} class does not implement restore_from_saved function"
             )

+    def update_quantizer(self, quantizer: Quantizer):
+        """Update quantizer for the tensor"""
+        if self._quantizer is None:
+            raise RuntimeError("To be updated, quantizer must be set")
+        if self._quantizer is not quantizer:
+            warnings.warn("Quantizer is being updated, this may affect model behavior")
+            self._quantizer = quantizer
+

 def prepare_for_saving(
     *tensors: Union[torch.Tensor, QuantizedTensorBase],
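A small usage sketch for the new update_quantizer method (the helper name is hypothetical; it only forwards to the method added above):

    from transformer_engine.pytorch.tensor.quantized_tensor import QuantizedTensorBase, Quantizer

    def swap_quantizer(t: QuantizedTensorBase, q: Quantizer) -> None:
        # update_quantizer raises RuntimeError if the tensor has no quantizer yet,
        # and warns when the quantizer object actually changes.
        t.update_quantizer(q)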
@@ -238,6 +250,10 @@ class Quantizer(abc.ABC):
         """Create shallow copy"""
         return copy.copy(self)

+    @abc.abstractmethod
+    def _get_compatible_recipe(self) -> Union[type[Recipe], None]:
+        """Returns recipe class that is compatible with this quantizer"""
+

 class _QuantizeFunc(torch.autograd.Function):
     """Cast to FP8 from other dtype"""
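Together with the concrete overrides above (DelayedScaling for Float8Quantizer, Float8CurrentScaling for Float8CurrentScalingQuantizer, MXFP8BlockScaling for MXFP8Quantizer), this abstract method lets callers test whether a quantizer matches the active recipe. A hypothetical compatibility check, not part of the commit:

    from transformer_engine.common.recipe import Recipe
    from transformer_engine.pytorch.tensor.quantized_tensor import Quantizer

    def is_compatible(quantizer: Quantizer, recipe: Recipe) -> bool:
        # Compare the active recipe instance against the recipe class the quantizer declares.
        compatible_cls = quantizer._get_compatible_recipe()
        return compatible_cls is not None and isinstance(recipe, compatible_cls)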
transformer_engine/pytorch/transformer.py (+111, −69)
@@ -179,12 +179,12 @@ class TransformerLayer(torch.nn.Module):
              The device on which the parameters of the model will be allocated. It is the user's
              responsibility to ensure all parameters are moved to the GPU before running the
              forward pass.
-    attn_input_format: {'sbhd', 'bshd'}, default = 'sbhd'
+    attn_input_format: {'sbhd', 'bshd', 'thd'}, default = 'sbhd'
                       This controls whether the dimensions of the
-                      intermediate hidden states is 'batch first' ('bshd') or
-                      'sequence first' ('sbhd'). `s` stands for the sequence
-                      length, `b` batch size, `h` the number of heads, `d`
+                      intermediate hidden states is 'sequence first' ('sbhd'),
+                      'batch first' ('bshd') or 'token first' ('thd'). `s` stands for the sequence
+                      length, `b` batch size, `t` the total number of tokens, `h` the number of heads, `d`
                       head size. Note that these formats are very closely
                       related to the `qkv_format` in the `MultiHeadAttention`
                       and `DotProductAttention` modules.
     name: str, default = `None`
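For readers unfamiliar with the new 'thd' option: instead of a padded [s, b, ...] or [b, s, ...] tensor, sequences are concatenated along a single token dimension and described by cumulative-length tensors. An illustrative sketch with made-up sizes:

    import torch

    # Three sequences of lengths 5, 3 and 7 packed along one token dimension.
    cu_seqlens = torch.tensor([0, 5, 8, 15], dtype=torch.int32)  # cumulative lengths
    hidden_size = 64
    total_tokens = int(cu_seqlens[-1])                           # t = 15
    hidden_states = torch.randn(total_tokens, hidden_size)       # 'thd'-style packed input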
@@ -235,6 +235,14 @@ class TransformerLayer(torch.nn.Module):
                     parameter for query-key-value. This enables optimizations such as QKV
                     fusion without concatentations/splits and also enables the argument
                     `fuse_wgrad_accumulation`.
+    use_qk_norm: bool, default = 'False'
+                if set to `True`, L2 normalization is applied to query and key tensors
+                after RoPE (if applicable) but before attention computation.
+                This follows the Llama4 approach for QK normalization to improve
+                training stability and model performance.
+    qk_norm_eps: float, default = 1e-6
+                epsilon value for L2 normalization of query and key tensors.
+                Only used when `use_qk_norm` is True.
     """

     def __init__(
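A hedged construction example for the two new arguments (sizes are arbitrary and not taken from the commit):

    from transformer_engine.pytorch import TransformerLayer

    # L2-normalize Q and K after RoPE and before the attention computation.
    layer = TransformerLayer(
        hidden_size=1024,
        ffn_hidden_size=4096,
        num_attention_heads=16,
        use_qk_norm=True,   # new in this commit
        qk_norm_eps=1e-6,   # new in this commit
    )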
@@ -284,6 +292,8 @@ class TransformerLayer(torch.nn.Module):
         device: Union[torch.device, str] = "cuda",
         attn_input_format: str = "sbhd",
         name: str = None,
+        use_qk_norm: bool = False,
+        qk_norm_eps: float = 1e-6,
     ) -> None:
         super().__init__()
@@ -373,6 +383,8 @@ class TransformerLayer(torch.nn.Module):
             "ub_overlap_rs": ub_overlap_rs,
             "ub_overlap_rs_dgrad": ub_overlap_rs_dgrad,
             "qkv_format": self.attn_input_format,
+            "seq_length": seq_length,
+            "micro_batch_size": micro_batch_size,
         }

         self.self_attention = MultiheadAttention(
@@ -384,6 +396,8 @@ class TransformerLayer(torch.nn.Module):
             return_bias=not self.parallel_attention_mlp,
             normalization=normalization,
             device=device,
+            use_qk_norm=use_qk_norm,
+            qk_norm_eps=qk_norm_eps,
             name=name + ".self_attention" if name is not None else None,
         )
@@ -398,6 +412,8 @@ class TransformerLayer(torch.nn.Module):
             return_bias=True,
             normalization=normalization,
             device=device,
+            use_qk_norm=use_qk_norm,
+            qk_norm_eps=qk_norm_eps,
             name=name + ".inter_attention" if name is not None else None,
         )
@@ -552,6 +568,8 @@ class TransformerLayer(torch.nn.Module):
         alibi_slopes: Optional[torch.Tensor] = None,
         cu_seqlens_q: Optional[torch.Tensor] = None,
         cu_seqlens_kv: Optional[torch.Tensor] = None,
+        cu_seqlens_q_padded: Optional[torch.Tensor] = None,
+        cu_seqlens_kv_padded: Optional[torch.Tensor] = None,
         max_seqlen_q: Optional[int] = None,
         max_seqlen_kv: Optional[int] = None,
         fast_zero_fill: bool = True,
@@ -633,15 +651,25 @@ class TransformerLayer(torch.nn.Module):
         cu_seqlens_q: Optional[torch.Tensor], default = `None`
                    Cumulative sum of sequence lengths (without offset) in a batch for `query_layer`,
                    with shape [batch_size + 1] and dtype torch.int32.
+                   Used by encoders, or decoders' self-attention.
         cu_seqlens_kv: Optional[torch.Tensor], default = `None`
                    Cumulative sum of sequence lengths (without offset) in a batch for `key_layer`
                    and `value_layer`, with shape [batch_size + 1] and dtype torch.int32.
+                   Used by decoders' cross-attention.
+        cu_seqlens_q_padded: Optional[torch.Tensor], default = `None`
+                   Cumulative sum of sequence lengths (with offset) in a batch for `query_layer`,
+                   with shape [batch_size + 1] and dtype torch.int32. Set to `cu_seqlens_q` if None.
+                   Used by encoders, or decoders' self-attention.
+        cu_seqlens_kv_padded: Optional[torch.Tensor], default = `None`
+                   Cumulative sum of sequence lengths (with offset) in a batch for `key_layer`
+                   and `value_layer`, with shape [batch_size + 1] and dtype torch.int32.
+                   Set to `cu_seqlens_kv` if None. Used by decoders' cross-attention.
         max_seqlen_q: Optional[int], default = `None`
                    Maximum sequence length in `query_layer`.
-                   Calculated from `cu_seqlens_q` if not provided.
+                   Calculated from `cu_seqlens_q_padded` if not provided.
         max_seqlen_kv: Optional[int], default = `None`
                    Maximum sequence length in `key_layer` and `value_layer`.
-                   Calculated from `cu_seqlens_kv` if not provided.
+                   Calculated from `cu_seqlens_kv_padded` if not provided.
         fast_zero_fill: bool, default = `True`
                    Whether to set output tensors to 0 or not before use.
         inference_params: InferenceParams, default = None
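To make the 'without offset' vs 'with offset' wording concrete: cu_seqlens_* accumulates the actual sequence lengths, while cu_seqlens_*_padded accumulates the lengths as laid out in memory, including padding tokens between sequences. An illustrative example with made-up numbers:

    import torch

    # Actual lengths 5, 3 and 7, each padded to a multiple of 4 tokens (8, 4 and 8),
    # so the packed buffer holds 20 slots for 15 real tokens.
    cu_seqlens_q        = torch.tensor([0, 5, 8, 15], dtype=torch.int32)   # without offset
    cu_seqlens_q_padded = torch.tensor([0, 8, 12, 20], dtype=torch.int32)  # with offset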
@@ -649,7 +677,8 @@ class TransformerLayer(torch.nn.Module):
                    to efficiently calculate and store the context during inference.
         pad_between_seqs: Optional[bool], default = `None`
                    If None, inferred from qkv_format, cu_seqlens and cu_seqlens_padded.
-                   If true, there are padding tokens between individual sequences in a packed batch.
+                   If true, there are padding tokens between individual sequences in a packed batch,
+                   i.e. qkv_format = 'thd'.
         """
         if self_attn_mask_type is None:
@@ -678,7 +707,9 @@ class TransformerLayer(torch.nn.Module):
         if (
             "padding" in self_attn_mask_type or self_attn_mask_type == "arbitrary"
         ) and attention_mask is not None:
-            assert attention_mask.dtype == torch.bool, "Attention mask must be a boolean tensor"
+            assert all(
+                attention_mask[i].dtype == torch.bool for i in range(len(attention_mask))
+            ), "Attention mask must be a boolean tensor or a list/tuple of two boolean tensors"
         if (
             "padding" in enc_dec_attn_mask_type or enc_dec_attn_mask_type == "arbitrary"
         ) and enc_dec_attn_mask is not None:
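The relaxed assertion accepts either a single boolean mask or a pair of masks (one over query positions, one over key/value positions). A hypothetical way to build such a pair; the shapes and the convention that True marks a masked-out position are assumptions here, not taken from the commit:

    import torch

    batch, max_q, max_kv = 2, 8, 12
    mask_q  = torch.zeros(batch, 1, 1, max_q,  dtype=torch.bool)
    mask_kv = torch.zeros(batch, 1, 1, max_kv, dtype=torch.bool)
    mask_q[0, ..., 6:]   = True   # last two query tokens of sample 0 are padding
    mask_kv[0, ..., 10:] = True   # last two key/value tokens of sample 0 are padding
    attention_mask = (mask_q, mask_kv)   # passes the new list/tuple-of-bools check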
@@ -707,9 +738,11 @@ class TransformerLayer(torch.nn.Module):
                 core_attention_bias=core_attention_bias,
                 alibi_slopes=alibi_slopes,
                 cu_seqlens_q=cu_seqlens_q,
-                cu_seqlens_kv=cu_seqlens_kv,
+                cu_seqlens_kv=cu_seqlens_q,
+                cu_seqlens_q_padded=cu_seqlens_q_padded,
+                cu_seqlens_kv_padded=cu_seqlens_q_padded,
                 max_seqlen_q=max_seqlen_q,
-                max_seqlen_kv=max_seqlen_kv,
+                max_seqlen_kv=max_seqlen_q,
                 fast_zero_fill=fast_zero_fill,
                 pad_between_seqs=pad_between_seqs,
             )
@@ -733,12 +766,21 @@ class TransformerLayer(torch.nn.Module):
                 attn_mask_type=enc_dec_attn_mask_type,
                 window_size=enc_dec_window_size,
                 encoder_output=encoder_output,
+                inference_params=inference_params,
                 is_first_microbatch=is_first_microbatch,
                 checkpoint_core_attention=checkpoint_core_attention,
+                rotary_pos_emb=rotary_pos_emb,
                 core_attention_bias_type=core_attention_bias_type,
                 core_attention_bias=core_attention_bias,
                 alibi_slopes=alibi_slopes,
+                cu_seqlens_q=cu_seqlens_q,
+                cu_seqlens_kv=cu_seqlens_kv,
+                cu_seqlens_q_padded=cu_seqlens_q_padded,
+                cu_seqlens_kv_padded=cu_seqlens_kv_padded,
+                max_seqlen_q=max_seqlen_q,
+                max_seqlen_kv=max_seqlen_kv,
                 fast_zero_fill=fast_zero_fill,
+                pad_between_seqs=pad_between_seqs,
             )

             if self.apply_residual_connection_post_layernorm:
                 attention_output, attention_bias, residual = inter_attention_outputs
transformer_engine/pytorch/utils.py (+9, −0)
@@ -37,8 +37,16 @@ def clear_tensor_data(*tensors: Tuple[Optional[torch.Tensor], ...]) -> None:
     Must be used carefully.

     """
     for t in tensors:
         if t is not None:
+            # Workaround for double buffering in cpu offload
+            if hasattr(t, "do_not_clear"):
+                continue
+            if hasattr(t, "get_data_tensors"):
+                if any(
+                    hasattr(tensor, "do_not_clear") for tensor in t.get_data_tensors()
+                ):
+                    continue
             if hasattr(t, "clear"):
                 t.clear()
             else:
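A sketch of the CPU-offload workaround this hunk checks for: marking a tensor so that clear_tensor_data skips it. The attribute name comes from the hasattr checks above; whether user code should set it directly is an assumption:

    import torch
    from transformer_engine.pytorch.utils import clear_tensor_data

    activation = torch.randn(4, 4)
    activation.do_not_clear = True   # flag inspected by the new hasattr() checks
    clear_tensor_data(activation)    # skips this tensor instead of clearing its data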
@@ -462,6 +470,7 @@ def is_bf16_compatible() -> None:
     return torch.cuda.get_device_capability()[0] >= 8


+@functools.lru_cache(maxsize=None)
 def is_non_tn_fp8_gemm_supported() -> bool:
     """Checks whether the device supports
     non-TN layouts for FP8 GEMMs.
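The only change here is memoization: functools.lru_cache(maxsize=None) makes repeated calls return the cached result instead of re-querying the device. The same pattern in isolation (hypothetical function name):

    import functools
    import torch

    @functools.lru_cache(maxsize=None)
    def device_compute_capability() -> tuple:
        # Queried once per process; later calls hit the cache.
        return torch.cuda.get_device_capability()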