Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
TransformerEngine
Commits
0d874a4e
Commit
0d874a4e
authored
Mar 03, 2026
by
wenjh
Browse files
Merge branch 'nv_main' of v2.12
parents
a68e5f87
dfdd3820
Changes
640
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
707 additions
and
26 deletions
+707
-26
transformer_engine/common/util/multi_stream.h
transformer_engine/common/util/multi_stream.h
+1
-1
transformer_engine/common/util/padding.cu
transformer_engine/common/util/padding.cu
+6
-6
transformer_engine/common/util/ptx.cuh
transformer_engine/common/util/ptx.cuh
+681
-1
transformer_engine/common/util/pybind_helper.h
transformer_engine/common/util/pybind_helper.h
+3
-2
transformer_engine/common/util/rtc.cpp
transformer_engine/common/util/rtc.cpp
+1
-1
transformer_engine/common/util/rtc.h
transformer_engine/common/util/rtc.h
+1
-1
transformer_engine/common/util/shared_lib_wrapper.h
transformer_engine/common/util/shared_lib_wrapper.h
+1
-1
transformer_engine/common/util/string.h
transformer_engine/common/util/string.h
+1
-1
transformer_engine/common/util/string_header.h.in
transformer_engine/common/util/string_header.h.in
+1
-1
transformer_engine/common/util/system.h
transformer_engine/common/util/system.h
+1
-1
transformer_engine/common/util/vectorized_pointwise.h
transformer_engine/common/util/vectorized_pointwise.h
+1
-1
transformer_engine/common/utils.cuh
transformer_engine/common/utils.cuh
+1
-1
transformer_engine/common/utils.py
transformer_engine/common/utils.py
+1
-1
transformer_engine/debug/__init__.py
transformer_engine/debug/__init__.py
+1
-1
transformer_engine/debug/features/__init__.py
transformer_engine/debug/features/__init__.py
+1
-1
transformer_engine/debug/features/_test_dummy_feature.py
transformer_engine/debug/features/_test_dummy_feature.py
+1
-1
transformer_engine/debug/features/api.py
transformer_engine/debug/features/api.py
+1
-1
transformer_engine/debug/features/disable_fp8_gemm.py
transformer_engine/debug/features/disable_fp8_gemm.py
+1
-1
transformer_engine/debug/features/disable_fp8_layer.py
transformer_engine/debug/features/disable_fp8_layer.py
+1
-1
transformer_engine/debug/features/fake_quant.py
transformer_engine/debug/features/fake_quant.py
+1
-1
No files found.
Too many changes to show.
To preserve performance only
640 of 640+
files are displayed.
Plain diff
Email patch
transformer_engine/common/util/multi_stream.h
View file @
0d874a4e
/*************************************************************************
* Copyright (c) 2022-202
5
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2022-202
6
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
...
...
transformer_engine/common/util/padding.cu
View file @
0d874a4e
/*************************************************************************
* Copyright (c) 2022-202
5
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2022-202
6
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
...
...
@@ -101,7 +101,7 @@ __global__ void __launch_bounds__(threads_per_block) multi_padding_kernel(MultiP
if
(
row
<
num_rows
)
{
for
(
int
j2
=
0
;
j2
<
nvec
;
++
j2
)
{
if
(
col
+
j2
<
row_length
)
{
local_input
.
data
.
elt
[
j2
]
=
input
[
row
*
row_length
+
col
+
j2
];
local_input
.
data
.
elt
[
j2
]
=
input
[
static_cast
<
size_t
>
(
row
)
*
row_length
+
col
+
j2
];
}
}
}
...
...
@@ -112,14 +112,14 @@ __global__ void __launch_bounds__(threads_per_block) multi_padding_kernel(MultiP
if
(
row
<
num_rows
)
{
for
(
int
j2
=
0
;
j2
<
nvec
;
++
j2
)
{
if
(
col
+
j2
<
row_length
)
{
output
[
row
*
row_length
+
col
+
j2
]
=
local_output
.
data
.
elt
[
j2
];
output
[
static_cast
<
size_t
>
(
row
)
*
row_length
+
col
+
j2
]
=
local_output
.
data
.
elt
[
j2
];
}
}
}
else
if
(
row
<
padded_num_rows
)
{
// padding
for
(
int
j2
=
0
;
j2
<
nvec
;
++
j2
)
{
if
(
col
+
j2
<
row_length
)
{
output
[
row
*
row_length
+
col
+
j2
]
=
local_zero
;
output
[
static_cast
<
size_t
>
(
row
)
*
row_length
+
col
+
j2
]
=
local_zero
;
}
}
}
...
...
@@ -185,7 +185,7 @@ __global__ void __launch_bounds__(threads_per_block) multi_unpadding_kernel(Mult
if
(
row
<
num_rows
)
{
for
(
int
j2
=
0
;
j2
<
nvec
;
++
j2
)
{
if
(
col
+
j2
<
row_length
)
{
local_input
.
data
.
elt
[
j2
]
=
input
[
row
*
row_length
+
col
+
j2
];
local_input
.
data
.
elt
[
j2
]
=
input
[
static_cast
<
size_t
>
(
row
)
*
row_length
+
col
+
j2
];
}
}
}
...
...
@@ -196,7 +196,7 @@ __global__ void __launch_bounds__(threads_per_block) multi_unpadding_kernel(Mult
if
(
row
<
num_rows
)
{
for
(
int
j2
=
0
;
j2
<
nvec
;
++
j2
)
{
if
(
col
+
j2
<
row_length
)
{
output
[
row
*
row_length
+
col
+
j2
]
=
local_output
.
data
.
elt
[
j2
];
output
[
static_cast
<
size_t
>
(
row
)
*
row_length
+
col
+
j2
]
=
local_output
.
data
.
elt
[
j2
];
}
}
}
...
...
transformer_engine/common/util/ptx.cuh
View file @
0d874a4e
/*************************************************************************
* Copyright (c) 2022-202
5
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2022-202
6
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
...
...
@@ -840,8 +840,688 @@ __device__ __forceinline__ void abs_max_2x(fp16x2 &dst, const fp16x2 &p1, const
#endif // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 890)
}
__device__
__forceinline__
int32_t
elect_one_sync
(
uint32_t
mask
=
0xFFFFFFFFu
)
{
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
int32_t
pred
=
0
;
asm
volatile
(
"{
\n\t
"
".reg .pred %px;
\n
"
"elect.sync _|%px, %1;
\n
"
"selp.b32 %0, 1, 0, %px;
\n
"
"
\n\t
}"
:
"=r"
(
pred
)
:
"r"
(
mask
));
return
pred
;
#else
NVTE_DEVICE_ERROR
(
"elect_one_sync is only supported on SM 10.0+."
);
return
0
;
#endif // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
}
__device__
__forceinline__
void
numbered_barrier_sync
(
uint32_t
num_threads
,
uint32_t
barrier_id
=
1u
)
{
asm
volatile
(
"bar.sync %0, %1;
\n
"
::
"r"
(
barrier_id
),
"r"
(
num_threads
));
}
__device__
__forceinline__
void
fma_f32_f16
(
float
&
out
,
uint16_t
const
&
a
,
uint16_t
const
&
b
,
float
const
&
c
=
0.0
f
)
{
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
asm
volatile
(
"fma.rn.f32.f16 %0, %1, %2, %3;"
:
"=f"
(
out
)
:
"h"
(
a
),
"h"
(
b
),
"f"
(
c
)
:
"memory"
);
#else
NVTE_DEVICE_ERROR
(
"fma_f32_f16 is only supported on SM 10.0+."
);
#endif // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
}
__device__
__forceinline__
void
fma_f32_bf16
(
float
&
out
,
uint16_t
const
&
a
,
uint16_t
const
&
b
,
float
const
&
c
=
0.0
f
)
{
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
asm
volatile
(
"fma.rn.f32.bf16 %0, %1, %2, %3;"
:
"=f"
(
out
)
:
"h"
(
a
),
"h"
(
b
),
"f"
(
c
)
:
"memory"
);
#else
NVTE_DEVICE_ERROR
(
"fma_f32_bf16 is only supported on SM 10.0+."
);
#endif // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
}
__device__
__forceinline__
void
reduce_sync_max_abs_f32
(
float
&
out
,
float
const
&
in
)
{
constexpr
bool
is_sm_100f
=
NVTE_CUDA_ARCH_MATCHES
(
ptx
::
FamilySpecific
<
100
>
);
if
constexpr
(
is_sm_100f
)
{
asm
volatile
(
"redux.sync.max.abs.f32 %0, %1, 0xFFFFFFFF;"
:
"=f"
(
out
)
:
"f"
(
in
));
}
else
{
asm
volatile
(
"{
\n\t
"
".reg.b32 val;
\n
"
"abs.f32 val, %1;
\n
"
"redux.sync.max.u32 %0, val, 0xFFFFFFFF;
\n
"
"}
\n\t
"
:
"=r"
(
reinterpret_cast
<
uint32_t
&>
(
out
))
:
"f"
(
in
));
}
}
__device__
__forceinline__
bf16
get_amax
(
bf16
a
,
bf16
b
)
{
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
bf16
r
;
asm
volatile
(
"max.xorsign.abs.bf16 %0, %1, %2;"
:
"=h"
(
*
reinterpret_cast
<
int16_t
*>
(
&
r
))
:
"h"
(
*
reinterpret_cast
<
int16_t
*>
(
&
a
)),
"h"
(
*
reinterpret_cast
<
int16_t
*>
(
&
b
)));
return
r
;
#else
NVTE_DEVICE_ERROR
(
"get_amax is only supported on SM 10.0+."
);
return
0.
f
;
#endif // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
}
__device__
__forceinline__
fp16
get_amax
(
fp16
a
,
fp16
b
)
{
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
fp16
r
;
asm
volatile
(
"max.xorsign.abs.f16 %0, %1, %2;"
:
"=h"
(
*
reinterpret_cast
<
int16_t
*>
(
&
r
))
:
"h"
(
*
reinterpret_cast
<
int16_t
*>
(
&
a
)),
"h"
(
*
reinterpret_cast
<
int16_t
*>
(
&
b
)));
return
r
;
#else
NVTE_DEVICE_ERROR
(
"get_amax is only supported on SM 10.0+."
);
return
0.
f
;
#endif // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
}
__device__
__forceinline__
void
mul_cvt_4x
(
fp8e4m3x4
&
out
,
const
bf16x4
&
in
,
const
ptx
::
floatx2
&
scale
)
{
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
ptx
::
bf16x2
const
*
in2
=
reinterpret_cast
<
ptx
::
bf16x2
const
*>
(
&
in
);
asm
volatile
(
"{
\n\t
"
".reg.b32 val1;
\n\t
"
".reg.b32 val2;
\n\t
"
".reg.b32 val3;
\n\t
"
".reg.b32 val4;
\n\t
"
"prmt.b32 val2, 0x0, %1, 0x7632;
\n\t
"
"prmt.b32 val1, 0x0, %1, 0x5410;
\n\t
"
"prmt.b32 val4, 0x0, %2, 0x7632;
\n\t
"
"prmt.b32 val3, 0x0, %2, 0x5410;
\n\t
"
".reg.b64 val_1_2;
\n\t
"
".reg.b64 val_3_4;
\n\t
"
"mov.b64 val_1_2, {val1, val2};
\n\t
"
"mov.b64 val_3_4, {val3, val4};
\n\t
"
".reg.b64 zeros;
\n\t
"
"mov.b64 zeros, {0x0, 0x0};
\n\t
"
"fma.rn.f32x2 val_1_2, val_1_2, %3, zeros;
\n\t
"
"fma.rn.f32x2 val_3_4, val_3_4, %3, zeros;
\n\t
"
"mov.b64 {val1, val2}, val_1_2;
\n\t
"
"mov.b64 {val3, val4}, val_3_4;
\n\t
"
#if (defined _LOOSE_PRECISION)
"cvt.rs.satfinite.e4m3x4.f32 %0, {val4, val3, val2, val1}, %4;
\n\t
"
#else
".reg.b16 r1;
\n\t
"
".reg.b16 r2;
\n\t
"
"cvt.rn.satfinite.e4m3x2.f32 r1, val2, val1;
\n\t
"
"cvt.rn.satfinite.e4m3x2.f32 r2, val4, val3;
\n\t
"
"mov.b32 %0, {r1, r2};
\n\t
"
#endif
"}
\n\t
"
:
"=r"
(
reinterpret_cast
<
uint32_t
&>
(
out
))
:
"r"
(
reinterpret_cast
<
const
uint32_t
&>
(
in2
[
0
])),
"r"
(
reinterpret_cast
<
const
uint32_t
&>
(
in2
[
1
])),
"l"
(
reinterpret_cast
<
const
uint64_t
&>
(
scale
)),
"r"
(
0x80008000
));
#else
NVTE_DEVICE_ERROR
(
"mul_cvt_4x is only supported on SM 10.0+."
);
#endif // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
}
__device__
__forceinline__
void
mul_cvt_4x
(
fp8e4m3x4
&
out
,
const
bf16x4
&
in
,
const
floatx4
&
scale
)
{
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
ptx
::
bf16x2
const
*
in2
=
reinterpret_cast
<
ptx
::
bf16x2
const
*>
(
&
in
);
ptx
::
floatx2
const
*
scale2
=
reinterpret_cast
<
ptx
::
floatx2
const
*>
(
&
scale
);
asm
volatile
(
"{
\n\t
"
".reg.b32 val1;
\n\t
"
".reg.b32 val2;
\n\t
"
".reg.b32 val3;
\n\t
"
".reg.b32 val4;
\n\t
"
"prmt.b32 val2, 0x0, %1, 0x7632;
\n\t
"
"prmt.b32 val1, 0x0, %1, 0x5410;
\n\t
"
"prmt.b32 val4, 0x0, %2, 0x7632;
\n\t
"
"prmt.b32 val3, 0x0, %2, 0x5410;
\n\t
"
".reg.b64 val_1_2;
\n\t
"
".reg.b64 val_3_4;
\n\t
"
"mov.b64 val_1_2, {val1, val2};
\n\t
"
"mov.b64 val_3_4, {val3, val4};
\n\t
"
".reg.b64 zeros;
\n\t
"
"mov.b64 zeros, {0x0, 0x0};
\n\t
"
"fma.rn.f32x2 val_1_2, val_1_2, %3, zeros;
\n\t
"
"fma.rn.f32x2 val_3_4, val_3_4, %4, zeros;
\n\t
"
"mov.b64 {val1, val2}, val_1_2;
\n\t
"
"mov.b64 {val3, val4}, val_3_4;
\n\t
"
#if (defined _LOOSE_PRECISION)
"cvt.rs.satfinite.e4m3x4.f32 %0, {val4, val3, val2, val1}, %4;
\n\t
"
#else
".reg.b16 r1;
\n\t
"
".reg.b16 r2;
\n\t
"
"cvt.rn.satfinite.e4m3x2.f32 r1, val2, val1;
\n\t
"
"cvt.rn.satfinite.e4m3x2.f32 r2, val4, val3;
\n\t
"
"mov.b32 %0, {r1, r2};
\n\t
"
#endif
"}
\n\t
"
:
"=r"
(
reinterpret_cast
<
uint32_t
&>
(
out
))
:
"r"
(
reinterpret_cast
<
const
uint32_t
&>
(
in2
[
0
])),
"r"
(
reinterpret_cast
<
const
uint32_t
&>
(
in2
[
1
])),
"l"
(
reinterpret_cast
<
const
uint64_t
&>
(
scale2
[
0
])),
"l"
(
reinterpret_cast
<
const
uint64_t
&>
(
scale2
[
1
])),
"r"
(
0x80008000
));
#else
NVTE_DEVICE_ERROR
(
"mul_cvt_4x is only supported on SM 10.0+."
);
#endif // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
}
__device__
__forceinline__
void
mul_cvt_4x
(
fp8e5m2x4
&
out
,
const
bf16x4
&
in
,
const
ptx
::
floatx2
&
scale
)
{
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
ptx
::
bf16x2
const
*
in2
=
reinterpret_cast
<
ptx
::
bf16x2
const
*>
(
&
in
);
asm
volatile
(
"{
\n\t
"
".reg.b32 val1;
\n\t
"
".reg.b32 val2;
\n\t
"
".reg.b32 val3;
\n\t
"
".reg.b32 val4;
\n\t
"
"prmt.b32 val2, 0x0, %1, 0x7632;
\n\t
"
"prmt.b32 val1, 0x0, %1, 0x5410;
\n\t
"
"prmt.b32 val4, 0x0, %2, 0x7632;
\n\t
"
"prmt.b32 val3, 0x0, %2, 0x5410;
\n\t
"
".reg.b64 val_1_2;
\n\t
"
".reg.b64 val_3_4;
\n\t
"
"mov.b64 val_1_2, {val1, val2};
\n\t
"
"mov.b64 val_3_4, {val3, val4};
\n\t
"
".reg.b64 zeros;
\n\t
"
"mov.b64 zeros, {0x0, 0x0};
\n\t
"
"fma.rn.f32x2 val_1_2, val_1_2, %3, zeros;
\n\t
"
"fma.rn.f32x2 val_3_4, val_3_4, %3, zeros;
\n\t
"
"mov.b64 {val1, val2}, val_1_2;
\n\t
"
"mov.b64 {val3, val4}, val_3_4;
\n\t
"
#if (defined _LOOSE_PRECISION)
"cvt.rs.satfinite.e5m2x4.f32 %0, {val4, val3, val2, val1}, %4;
\n\t
"
#else
".reg.b16 r1;
\n\t
"
".reg.b16 r2;
\n\t
"
"cvt.rn.satfinite.e5m2x2.f32 r1, val2, val1;
\n\t
"
"cvt.rn.satfinite.e5m2x2.f32 r2, val4, val3;
\n\t
"
"mov.b32 %0, {r1, r2};
\n\t
"
#endif
"}
\n\t
"
:
"=r"
(
reinterpret_cast
<
uint32_t
&>
(
out
))
:
"r"
(
reinterpret_cast
<
const
uint32_t
&>
(
in2
[
0
])),
"r"
(
reinterpret_cast
<
const
uint32_t
&>
(
in2
[
1
])),
"l"
(
reinterpret_cast
<
const
uint64_t
&>
(
scale
)),
"r"
(
0x80008000
));
#else
NVTE_DEVICE_ERROR
(
"mul_cvt_4x is only supported on SM 10.0+."
);
#endif // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
}
__device__
__forceinline__
void
mul_cvt_4x
(
fp8e5m2x4
&
out
,
const
bf16x4
&
in
,
const
floatx4
&
scale
)
{
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
ptx
::
bf16x2
const
*
in2
=
reinterpret_cast
<
ptx
::
bf16x2
const
*>
(
&
in
);
ptx
::
floatx2
const
*
scale2
=
reinterpret_cast
<
ptx
::
floatx2
const
*>
(
&
scale
);
asm
volatile
(
"{
\n\t
"
".reg.b32 val1;
\n\t
"
".reg.b32 val2;
\n\t
"
".reg.b32 val3;
\n\t
"
".reg.b32 val4;
\n\t
"
"prmt.b32 val2, 0x0, %1, 0x7632;
\n\t
"
"prmt.b32 val1, 0x0, %1, 0x5410;
\n\t
"
"prmt.b32 val4, 0x0, %2, 0x7632;
\n\t
"
"prmt.b32 val3, 0x0, %2, 0x5410;
\n\t
"
".reg.b64 val_1_2;
\n\t
"
".reg.b64 val_3_4;
\n\t
"
"mov.b64 val_1_2, {val1, val2};
\n\t
"
"mov.b64 val_3_4, {val3, val4};
\n\t
"
".reg.b64 zeros;
\n\t
"
"mov.b64 zeros, {0x0, 0x0};
\n\t
"
"fma.rn.f32x2 val_1_2, val_1_2, %3, zeros;
\n\t
"
"fma.rn.f32x2 val_3_4, val_3_4, %4, zeros;
\n\t
"
"mov.b64 {val1, val2}, val_1_2;
\n\t
"
"mov.b64 {val3, val4}, val_3_4;
\n\t
"
#if (defined _LOOSE_PRECISION)
"cvt.rs.satfinite.e5m2x4.f32 %0, {val4, val3, val2, val1}, %4;
\n\t
"
#else
".reg.b16 r1;
\n\t
"
".reg.b16 r2;
\n\t
"
"cvt.rn.satfinite.e5m2x2.f32 r1, val2, val1;
\n\t
"
"cvt.rn.satfinite.e5m2x2.f32 r2, val4, val3;
\n\t
"
"mov.b32 %0, {r1, r2};
\n\t
"
#endif
"}
\n\t
"
:
"=r"
(
reinterpret_cast
<
uint32_t
&>
(
out
))
:
"r"
(
reinterpret_cast
<
const
uint32_t
&>
(
in2
[
0
])),
"r"
(
reinterpret_cast
<
const
uint32_t
&>
(
in2
[
1
])),
"l"
(
reinterpret_cast
<
const
uint64_t
&>
(
scale2
[
0
])),
"l"
(
reinterpret_cast
<
const
uint64_t
&>
(
scale2
[
1
])),
"r"
(
0x80008000
));
#else
NVTE_DEVICE_ERROR
(
"mul_cvt_4x is only supported on SM 10.0+."
);
#endif // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
}
__device__
__forceinline__
void
mul_cvt_4x
(
fp8e4m3x4
&
out
,
const
fp16x4
&
in
,
const
ptx
::
floatx2
&
scale
)
{
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
ptx
::
fp16x2
const
*
in2
=
reinterpret_cast
<
ptx
::
fp16x2
const
*>
(
&
in
);
asm
volatile
(
"{
\n\t
"
".reg.b16 val1_f16;
\n\t
"
".reg.b16 val2_f16;
\n\t
"
".reg.b16 val3_f16;
\n\t
"
".reg.b16 val4_f16;
\n\t
"
"mov.b32 {val1_f16, val2_f16}, %1;
\n\t
"
"mov.b32 {val3_f16, val4_f16}, %2;
\n\t
"
".reg.b32 val1;
\n\t
"
".reg.b32 val2;
\n\t
"
".reg.b32 val3;
\n\t
"
".reg.b32 val4;
\n\t
"
"cvt.f32.f16 val1, val1_f16;
\n\t
"
"cvt.f32.f16 val2, val2_f16;
\n\t
"
"cvt.f32.f16 val3, val3_f16;
\n\t
"
"cvt.f32.f16 val4, val4_f16;
\n\t
"
".reg.b64 val_1_2;
\n\t
"
".reg.b64 val_3_4;
\n\t
"
"mov.b64 val_1_2, {val1, val2};
\n\t
"
"mov.b64 val_3_4, {val3, val4};
\n\t
"
".reg.b64 zeros;
\n\t
"
"mov.b64 zeros, {0x0, 0x0};
\n\t
"
"fma.rn.f32x2 val_1_2, val_1_2, %3, zeros;
\n\t
"
"fma.rn.f32x2 val_3_4, val_3_4, %3, zeros;
\n\t
"
"mov.b64 {val1, val2}, val_1_2;
\n\t
"
"mov.b64 {val3, val4}, val_3_4;
\n\t
"
#if (defined _LOOSE_PRECISION)
"cvt.rs.satfinite.e4m3x4.f32 %0, {val4, val3, val2, val1}, %4;
\n\t
"
#else
".reg.b16 r1;
\n\t
"
".reg.b16 r2;
\n\t
"
"cvt.rn.satfinite.e4m3x2.f32 r1, val2, val1;
\n\t
"
"cvt.rn.satfinite.e4m3x2.f32 r2, val4, val3;
\n\t
"
"mov.b32 %0, {r1, r2};
\n\t
"
#endif
"}
\n\t
"
:
"=r"
(
reinterpret_cast
<
uint32_t
&>
(
out
))
:
"r"
(
reinterpret_cast
<
const
uint32_t
&>
(
in2
[
0
])),
"r"
(
reinterpret_cast
<
const
uint32_t
&>
(
in2
[
1
])),
"l"
(
reinterpret_cast
<
const
uint64_t
&>
(
scale
)),
"r"
(
0x80008000
));
#else
NVTE_DEVICE_ERROR
(
"mul_cvt_4x is only supported on SM 10.0+."
);
#endif // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
}
__device__
__forceinline__
void
mul_cvt_4x
(
fp8e4m3x4
&
out
,
const
fp16x4
&
in
,
const
floatx4
&
scale
)
{
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
ptx
::
fp16x2
const
*
in2
=
reinterpret_cast
<
ptx
::
fp16x2
const
*>
(
&
in
);
ptx
::
floatx2
const
*
scale2
=
reinterpret_cast
<
ptx
::
floatx2
const
*>
(
&
scale
);
asm
volatile
(
"{
\n\t
"
".reg.b16 val1_f16;
\n\t
"
".reg.b16 val2_f16;
\n\t
"
".reg.b16 val3_f16;
\n\t
"
".reg.b16 val4_f16;
\n\t
"
"mov.b32 {val1_f16, val2_f16}, %1;
\n\t
"
"mov.b32 {val3_f16, val4_f16}, %2;
\n\t
"
".reg.b32 val1;
\n\t
"
".reg.b32 val2;
\n\t
"
".reg.b32 val3;
\n\t
"
".reg.b32 val4;
\n\t
"
"cvt.f32.f16 val1, val1_f16;
\n\t
"
"cvt.f32.f16 val2, val2_f16;
\n\t
"
"cvt.f32.f16 val3, val3_f16;
\n\t
"
"cvt.f32.f16 val4, val4_f16;
\n\t
"
".reg.b64 val_1_2;
\n\t
"
".reg.b64 val_3_4;
\n\t
"
"mov.b64 val_1_2, {val1, val2};
\n\t
"
"mov.b64 val_3_4, {val3, val4};
\n\t
"
".reg.b64 zeros;
\n\t
"
"mov.b64 zeros, {0x0, 0x0};
\n\t
"
"fma.rn.f32x2 val_1_2, val_1_2, %3, zeros;
\n\t
"
"fma.rn.f32x2 val_3_4, val_3_4, %4, zeros;
\n\t
"
"mov.b64 {val1, val2}, val_1_2;
\n\t
"
"mov.b64 {val3, val4}, val_3_4;
\n\t
"
#if (defined _LOOSE_PRECISION)
"cvt.rs.satfinite.e4m3x4.f32 %0, {val4, val3, val2, val1}, %4;
\n\t
"
#else
".reg.b16 r1;
\n\t
"
".reg.b16 r2;
\n\t
"
"cvt.rn.satfinite.e4m3x2.f32 r1, val2, val1;
\n\t
"
"cvt.rn.satfinite.e4m3x2.f32 r2, val4, val3;
\n\t
"
"mov.b32 %0, {r1, r2};
\n\t
"
#endif
"}
\n\t
"
:
"=r"
(
reinterpret_cast
<
uint32_t
&>
(
out
))
:
"r"
(
reinterpret_cast
<
const
uint32_t
&>
(
in2
[
0
])),
"r"
(
reinterpret_cast
<
const
uint32_t
&>
(
in2
[
1
])),
"l"
(
reinterpret_cast
<
const
uint64_t
&>
(
scale2
[
0
])),
"l"
(
reinterpret_cast
<
const
uint64_t
&>
(
scale2
[
1
])),
"r"
(
0x80008000
));
#else
NVTE_DEVICE_ERROR
(
"mul_cvt_4x is only supported on SM 10.0+."
);
#endif // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
}
__device__
__forceinline__
void
mul_cvt_4x
(
fp8e5m2x4
&
out
,
const
fp16x4
&
in
,
const
ptx
::
floatx2
&
scale
)
{
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
ptx
::
fp16x2
const
*
in2
=
reinterpret_cast
<
ptx
::
fp16x2
const
*>
(
&
in
);
asm
volatile
(
"{
\n\t
"
".reg.b16 val1_f16;
\n\t
"
".reg.b16 val2_f16;
\n\t
"
".reg.b16 val3_f16;
\n\t
"
".reg.b16 val4_f16;
\n\t
"
"mov.b32 {val1_f16, val2_f16}, %1;
\n\t
"
"mov.b32 {val3_f16, val4_f16}, %2;
\n\t
"
".reg.b32 val1;
\n\t
"
".reg.b32 val2;
\n\t
"
".reg.b32 val3;
\n\t
"
".reg.b32 val4;
\n\t
"
"cvt.f32.f16 val1, val1_f16;
\n\t
"
"cvt.f32.f16 val2, val2_f16;
\n\t
"
"cvt.f32.f16 val3, val3_f16;
\n\t
"
"cvt.f32.f16 val4, val4_f16;
\n\t
"
".reg.b64 val_1_2;
\n\t
"
".reg.b64 val_3_4;
\n\t
"
"mov.b64 val_1_2, {val1, val2};
\n\t
"
"mov.b64 val_3_4, {val3, val4};
\n\t
"
".reg.b64 zeros;
\n\t
"
"mov.b64 zeros, {0x0, 0x0};
\n\t
"
"fma.rn.f32x2 val_1_2, val_1_2, %3, zeros;
\n\t
"
"fma.rn.f32x2 val_3_4, val_3_4, %3, zeros;
\n\t
"
"mov.b64 {val1, val2}, val_1_2;
\n\t
"
"mov.b64 {val3, val4}, val_3_4;
\n\t
"
#if (defined _LOOSE_PRECISION)
"cvt.rs.satfinite.e5m2x4.f32 %0, {val4, val3, val2, val1}, %4;
\n\t
"
#else
".reg.b16 r1;
\n\t
"
".reg.b16 r2;
\n\t
"
"cvt.rn.satfinite.e5m2x2.f32 r1, val2, val1;
\n\t
"
"cvt.rn.satfinite.e5m2x2.f32 r2, val4, val3;
\n\t
"
"mov.b32 %0, {r1, r2};
\n\t
"
#endif
"}
\n\t
"
:
"=r"
(
reinterpret_cast
<
uint32_t
&>
(
out
))
:
"r"
(
reinterpret_cast
<
const
uint32_t
&>
(
in2
[
0
])),
"r"
(
reinterpret_cast
<
const
uint32_t
&>
(
in2
[
1
])),
"l"
(
reinterpret_cast
<
const
uint64_t
&>
(
scale
)),
"r"
(
0x80008000
));
#else
NVTE_DEVICE_ERROR
(
"mul_cvt_4x is only supported on SM 10.0+."
);
#endif // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
}
__device__
__forceinline__
void
mul_cvt_4x
(
fp8e5m2x4
&
out
,
const
fp16x4
&
in
,
const
floatx4
&
scale
)
{
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
ptx
::
fp16x2
const
*
in2
=
reinterpret_cast
<
ptx
::
fp16x2
const
*>
(
&
in
);
ptx
::
floatx2
const
*
scale2
=
reinterpret_cast
<
ptx
::
floatx2
const
*>
(
&
scale
);
asm
volatile
(
"{
\n\t
"
".reg.b16 val1_f16;
\n\t
"
".reg.b16 val2_f16;
\n\t
"
".reg.b16 val3_f16;
\n\t
"
".reg.b16 val4_f16;
\n\t
"
"mov.b32 {val1_f16, val2_f16}, %1;
\n\t
"
"mov.b32 {val3_f16, val4_f16}, %2;
\n\t
"
".reg.b32 val1;
\n\t
"
".reg.b32 val2;
\n\t
"
".reg.b32 val3;
\n\t
"
".reg.b32 val4;
\n\t
"
"cvt.f32.f16 val1, val1_f16;
\n\t
"
"cvt.f32.f16 val2, val2_f16;
\n\t
"
"cvt.f32.f16 val3, val3_f16;
\n\t
"
"cvt.f32.f16 val4, val4_f16;
\n\t
"
".reg.b64 val_1_2;
\n\t
"
".reg.b64 val_3_4;
\n\t
"
"mov.b64 val_1_2, {val1, val2};
\n\t
"
"mov.b64 val_3_4, {val3, val4};
\n\t
"
".reg.b64 zeros;
\n\t
"
"mov.b64 zeros, {0x0, 0x0};
\n\t
"
"fma.rn.f32x2 val_1_2, val_1_2, %3, zeros;
\n\t
"
"fma.rn.f32x2 val_3_4, val_3_4, %4, zeros;
\n\t
"
"mov.b64 {val1, val2}, val_1_2;
\n\t
"
"mov.b64 {val3, val4}, val_3_4;
\n\t
"
#if (defined _LOOSE_PRECISION)
"cvt.rs.satfinite.e5m2x4.f32 %0, {val4, val3, val2, val1}, %4;
\n\t
"
#else
".reg.b16 r1;
\n\t
"
".reg.b16 r2;
\n\t
"
"cvt.rn.satfinite.e5m2x2.f32 r1, val2, val1;
\n\t
"
"cvt.rn.satfinite.e5m2x2.f32 r2, val4, val3;
\n\t
"
"mov.b32 %0, {r1, r2};
\n\t
"
#endif
"}
\n\t
"
:
"=r"
(
reinterpret_cast
<
uint32_t
&>
(
out
))
:
"r"
(
reinterpret_cast
<
const
uint32_t
&>
(
in2
[
0
])),
"r"
(
reinterpret_cast
<
const
uint32_t
&>
(
in2
[
1
])),
"l"
(
reinterpret_cast
<
const
uint64_t
&>
(
scale2
[
0
])),
"l"
(
reinterpret_cast
<
const
uint64_t
&>
(
scale2
[
1
])),
"r"
(
0x80008000
));
#else
NVTE_DEVICE_ERROR
(
"mul_cvt_4x is only supported on SM 10.0+."
);
#endif // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
}
__device__
__forceinline__
void
mul_cvt_4x
(
fp8e5m2x4
&
out
,
floatx4
const
&
in
,
const
ptx
::
floatx2
&
scale
)
{
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
ptx
::
floatx2
const
*
in2
=
reinterpret_cast
<
ptx
::
floatx2
const
*>
(
&
in
);
asm
volatile
(
"{
\n\t
"
".reg.b64 zeros;
\n\t
"
"mov.b64 zeros, {0x0, 0x0};
\n\t
"
".reg.b64 re1;
\n\t
"
".reg.b64 re2;
\n\t
"
"fma.rn.f32x2 re1, %1, %3, zeros;
\n\t
"
"fma.rn.f32x2 re2, %2, %3, zeros;
\n\t
"
".reg.b32 val1;
\n\t
"
".reg.b32 val2;
\n\t
"
".reg.b32 val3;
\n\t
"
".reg.b32 val4;
\n\t
"
"mov.b64 {val1, val2}, re1;
\n\t
"
"mov.b64 {val3, val4}, re2;
\n\t
"
#if (defined _LOOSE_PRECISION)
"cvt.rs.satfinite.e5m2x4.f32 %0, {val4, val3, val2, val1}, %4;
\n\t
"
#else
".reg.b16 r1;
\n\t
"
".reg.b16 r2;
\n\t
"
"cvt.rn.satfinite.e5m2x2.f32 r1, val2, val1;
\n\t
"
"cvt.rn.satfinite.e5m2x2.f32 r2, val4, val3;
\n\t
"
"mov.b32 %0, {r1, r2};
\n\t
"
#endif
"}
\n\t
"
:
"=r"
(
reinterpret_cast
<
uint32_t
&>
(
out
))
:
"l"
(
reinterpret_cast
<
uint64_t
const
&>
(
in2
[
0
])),
"l"
(
reinterpret_cast
<
uint64_t
const
&>
(
in2
[
1
])),
"l"
(
reinterpret_cast
<
const
uint64_t
&>
(
scale
)),
"r"
(
0x80008000
));
#else
NVTE_DEVICE_ERROR
(
"mul_cvt_4x is only supported on SM 10.0+."
);
#endif // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
}
__device__
__forceinline__
void
mul_cvt_4x
(
fp8e5m2x4
&
out
,
floatx4
const
&
in
,
const
floatx4
&
scale
)
{
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
ptx
::
floatx2
const
*
in2
=
reinterpret_cast
<
ptx
::
floatx2
const
*>
(
&
in
);
ptx
::
floatx2
const
*
scale2
=
reinterpret_cast
<
ptx
::
floatx2
const
*>
(
&
scale
);
asm
volatile
(
"{
\n\t
"
".reg.b64 zeros;
\n\t
"
"mov.b64 zeros, {0x0, 0x0};
\n\t
"
".reg.b64 re1;
\n\t
"
".reg.b64 re2;
\n\t
"
"fma.rn.f32x2 re1, %1, %3, zeros;
\n\t
"
"fma.rn.f32x2 re2, %2, %4, zeros;
\n\t
"
".reg.b32 val1;
\n\t
"
".reg.b32 val2;
\n\t
"
".reg.b32 val3;
\n\t
"
".reg.b32 val4;
\n\t
"
"mov.b64 {val1, val2}, re1;
\n\t
"
"mov.b64 {val3, val4}, re2;
\n\t
"
#if (defined _LOOSE_PRECISION)
"cvt.rs.satfinite.e5m2x4.f32 %0, {val4, val3, val2, val1}, %4;
\n\t
"
#else
".reg.b16 r1;
\n\t
"
".reg.b16 r2;
\n\t
"
"cvt.rn.satfinite.e5m2x2.f32 r1, val2, val1;
\n\t
"
"cvt.rn.satfinite.e5m2x2.f32 r2, val4, val3;
\n\t
"
"mov.b32 %0, {r1, r2};
\n\t
"
#endif
"}
\n\t
"
:
"=r"
(
reinterpret_cast
<
uint32_t
&>
(
out
))
:
"l"
(
reinterpret_cast
<
uint64_t
const
&>
(
in2
[
0
])),
"l"
(
reinterpret_cast
<
uint64_t
const
&>
(
in2
[
1
])),
"l"
(
reinterpret_cast
<
const
uint64_t
&>
(
scale2
[
0
])),
"l"
(
reinterpret_cast
<
const
uint64_t
&>
(
scale2
[
1
])),
"r"
(
0x80008000
));
#else
NVTE_DEVICE_ERROR
(
"mul_cvt_4x is only supported on SM 10.0+."
);
#endif // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
}
__device__
__forceinline__
void
mul_cvt_4x
(
fp8e4m3x4
&
out
,
floatx4
const
&
in
,
const
ptx
::
floatx2
&
scale
)
{
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
ptx
::
floatx2
const
*
in2
=
reinterpret_cast
<
ptx
::
floatx2
const
*>
(
&
in
);
asm
volatile
(
"{
\n\t
"
".reg.b64 zeros;
\n\t
"
"mov.b64 zeros, {0x0, 0x0};
\n\t
"
".reg.b64 re1;
\n\t
"
".reg.b64 re2;
\n\t
"
"fma.rn.f32x2 re1, %1, %3, zeros;
\n\t
"
"fma.rn.f32x2 re2, %2, %3, zeros;
\n\t
"
".reg.b32 val1;
\n\t
"
".reg.b32 val2;
\n\t
"
".reg.b32 val3;
\n\t
"
".reg.b32 val4;
\n\t
"
"mov.b64 {val1, val2}, re1;
\n\t
"
"mov.b64 {val3, val4}, re2;
\n\t
"
#if (defined _LOOSE_PRECISION)
"cvt.rs.satfinite.e4m3x4.f32 %0, {val4, val3, val2, val1}, %4;
\n\t
"
#else
".reg.b16 r1;
\n\t
"
".reg.b16 r2;
\n\t
"
"cvt.rn.satfinite.e4m3x2.f32 r1, val2, val1;
\n\t
"
"cvt.rn.satfinite.e4m3x2.f32 r2, val4, val3;
\n\t
"
"mov.b32 %0, {r1, r2};
\n\t
"
#endif
"}
\n\t
"
:
"=r"
(
reinterpret_cast
<
uint32_t
&>
(
out
))
:
"l"
(
reinterpret_cast
<
uint64_t
const
&>
(
in2
[
0
])),
"l"
(
reinterpret_cast
<
uint64_t
const
&>
(
in2
[
1
])),
"l"
(
reinterpret_cast
<
const
uint64_t
&>
(
scale
)),
"r"
(
0x80008000
));
#else
NVTE_DEVICE_ERROR
(
"mul_cvt_4x is only supported on SM 10.0+."
);
#endif // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
}
__device__
__forceinline__
void
mul_cvt_4x
(
fp8e4m3x4
&
out
,
floatx4
const
&
in
,
const
floatx4
&
scale
)
{
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
ptx
::
floatx2
const
*
in2
=
reinterpret_cast
<
ptx
::
floatx2
const
*>
(
&
in
);
ptx
::
floatx2
const
*
scale2
=
reinterpret_cast
<
ptx
::
floatx2
const
*>
(
&
scale
);
asm
volatile
(
"{
\n\t
"
".reg.b64 zeros;
\n\t
"
"mov.b64 zeros, {0x0, 0x0};
\n\t
"
".reg.b64 re1;
\n\t
"
".reg.b64 re2;
\n\t
"
"fma.rn.f32x2 re1, %1, %3, zeros;
\n\t
"
"fma.rn.f32x2 re2, %2, %4, zeros;
\n\t
"
".reg.b32 val1;
\n\t
"
".reg.b32 val2;
\n\t
"
".reg.b32 val3;
\n\t
"
".reg.b32 val4;
\n\t
"
"mov.b64 {val1, val2}, re1;
\n\t
"
"mov.b64 {val3, val4}, re2;
\n\t
"
#if (defined _LOOSE_PRECISION)
"cvt.rs.satfinite.e4m3x4.f32 %0, {val4, val3, val2, val1}, %4;
\n\t
"
#else
".reg.b16 r1;
\n\t
"
".reg.b16 r2;
\n\t
"
"cvt.rn.satfinite.e4m3x2.f32 r1, val2, val1;
\n\t
"
"cvt.rn.satfinite.e4m3x2.f32 r2, val4, val3;
\n\t
"
"mov.b32 %0, {r1, r2};
\n\t
"
#endif
"}
\n\t
"
:
"=r"
(
reinterpret_cast
<
uint32_t
&>
(
out
))
:
"l"
(
reinterpret_cast
<
uint64_t
const
&>
(
in2
[
0
])),
"l"
(
reinterpret_cast
<
uint64_t
const
&>
(
in2
[
1
])),
"l"
(
reinterpret_cast
<
const
uint64_t
&>
(
scale2
[
0
])),
"l"
(
reinterpret_cast
<
const
uint64_t
&>
(
scale2
[
1
])),
"r"
(
0x80008000
));
#else
NVTE_DEVICE_ERROR
(
"mul_cvt_4x is only supported on SM 10.0+."
);
#endif // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
}
__device__
__forceinline__
void
abs_max_2x
(
float
&
dst
,
const
float
&
p1
,
const
float
&
p2
,
const
float
&
p3
)
{
#if (defined CUDA_VERSION) && (CUDA_VERSION >= 12090)
asm
volatile
(
"max.abs.f32 %0, %1, %2, %3;"
:
"=f"
(
dst
)
:
"f"
(
p1
),
"f"
(
p2
),
"f"
(
p3
));
#else
asm
volatile
(
"max.xorsign.abs.f32 %0, %2, %3;"
"max.xorsign.abs.f32 %0, %0, %1;"
:
"+f"
(
dst
)
:
"f"
(
p1
),
"f"
(
p2
),
"f"
(
p3
));
#endif
}
__device__
__forceinline__
ptx
::
floatx2
up_cast
(
const
ptx
::
fp16x2
&
in
)
{
ptx
::
floatx2
out
;
asm
volatile
(
"{
\n\t
"
".reg.b16 f16_1;
\n\t
"
".reg.b16 f16_2;
\n\t
"
"mov.b32 {f16_1, f16_2}, %2;
\n\t
"
"cvt.f32.f16 %0, f16_1;
\n\t
"
"cvt.f32.f16 %1, f16_2;
\n\t
"
"}
\n\t
"
:
"=f"
(
out
.
x
),
"=f"
(
out
.
y
)
:
"r"
(
reinterpret_cast
<
int32_t
const
&>
(
in
)));
return
out
;
}
__device__
__forceinline__
floatx4
up_cast
(
const
fp16x4
&
in
)
{
floatx4
out
;
asm
volatile
(
"{
\n\t
"
".reg.b16 f16_1;
\n\t
"
".reg.b16 f16_2;
\n\t
"
".reg.b16 f16_3;
\n\t
"
".reg.b16 f16_4;
\n\t
"
"mov.b64 {f16_1, f16_2, f16_3, f16_4}, %4;
\n\t
"
"cvt.f32.f16 %0, f16_1;
\n\t
"
"cvt.f32.f16 %1, f16_2;
\n\t
"
"cvt.f32.f16 %2, f16_3;
\n\t
"
"cvt.f32.f16 %3, f16_4;
\n\t
"
"}
\n\t
"
:
"=f"
(
out
.
x1
),
"=f"
(
out
.
x2
),
"=f"
(
out
.
x3
),
"=f"
(
out
.
x4
)
:
"l"
(
reinterpret_cast
<
int64_t
const
&>
(
in
)));
return
out
;
}
__device__
__forceinline__
ptx
::
floatx2
up_cast
(
const
ptx
::
bf16x2
&
in
)
{
ptx
::
floatx2
out
;
asm
volatile
(
"{
\n\t
"
"prmt.b32 %1, 0x0, %2, 0x7632;
\n\t
"
"prmt.b32 %0, 0x0, %2, 0x5410;
\n\t
"
"}
\n\t
"
:
"=r"
(
reinterpret_cast
<
int32_t
&>
(
out
.
x
)),
"=r"
(
reinterpret_cast
<
int32_t
&>
(
out
.
y
))
:
"r"
(
reinterpret_cast
<
int32_t
const
&>
(
in
)));
return
out
;
}
__device__
__forceinline__
floatx4
up_cast
(
const
bf16x4
&
in
)
{
floatx4
out
;
int32_t
const
*
in2
=
reinterpret_cast
<
int32_t
const
*>
(
&
in
);
asm
volatile
(
"{
\n\t
"
"prmt.b32 %1, 0x0, %4, 0x7632;
\n\t
"
"prmt.b32 %0, 0x0, %4, 0x5410;
\n\t
"
"prmt.b32 %3, 0x0, %5, 0x7632;
\n\t
"
"prmt.b32 %2, 0x0, %5, 0x5410;
\n\t
"
"}
\n\t
"
:
"=r"
(
reinterpret_cast
<
int32_t
&>
(
out
.
x1
)),
"=r"
(
reinterpret_cast
<
int32_t
&>
(
out
.
x2
)),
"=r"
(
reinterpret_cast
<
int32_t
&>
(
out
.
x3
)),
"=r"
(
reinterpret_cast
<
int32_t
&>
(
out
.
x4
))
:
"r"
(
in2
[
0
]),
"r"
(
in2
[
1
]));
return
out
;
}
#endif
}
// namespace ptx
namespace
{
...
...
transformer_engine/common/util/pybind_helper.h
View file @
0d874a4e
/*************************************************************************
* Copyright (c) 2022-202
5
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2022-202
6
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
...
...
@@ -88,7 +88,8 @@
pybind11::enum_<transformer_engine::Float8BlockScaleTensorFormat>( \
m, "Float8BlockScaleTensorFormat", pybind11::module_local()) \
.value("GEMM_READY", transformer_engine::Float8BlockScaleTensorFormat::GEMM_READY) \
.value("COMPACT", transformer_engine::Float8BlockScaleTensorFormat::COMPACT); \
.value("COMPACT", transformer_engine::Float8BlockScaleTensorFormat::COMPACT) \
.value("INVALID", transformer_engine::Float8BlockScaleTensorFormat::INVALID); \
pybind11::enum_<transformer_engine::CommOverlapType>(m, "CommOverlapType", \
pybind11::module_local()) \
.value("RS", transformer_engine::CommOverlapType::RS) \
...
...
transformer_engine/common/util/rtc.cpp
View file @
0d874a4e
/*************************************************************************
 * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
...
...
transformer_engine/common/util/rtc.h
View file @
0d874a4e
/*************************************************************************
 * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
...
...
transformer_engine/common/util/shared_lib_wrapper.h
View file @
0d874a4e
/*************************************************************************
 * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
...
...
transformer_engine/common/util/string.h
View file @
0d874a4e
/*************************************************************************
 * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
...
...
transformer_engine/common/util/string_header.h.in
View file @
0d874a4e
/*************************************************************************
 * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
...
...
transformer_engine/common/util/system.h
View file @
0d874a4e
/*************************************************************************
 * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
...
...
transformer_engine/common/util/vectorized_pointwise.h
View file @
0d874a4e
/*************************************************************************
 * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
...
...
transformer_engine/common/utils.cuh
View file @
0d874a4e
/*************************************************************************
 * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
...
...
transformer_engine/common/utils.py
View file @
0d874a4e
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""The utilities for Transformer Engine"""
...
...
transformer_engine/debug/__init__.py
View file @
0d874a4e
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
...
...
transformer_engine/debug/features/__init__.py
View file @
0d874a4e
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
...
...
transformer_engine/debug/features/_test_dummy_feature.py
View file @
0d874a4e
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
...
...
transformer_engine/debug/features/api.py
View file @
0d874a4e
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
...
...
transformer_engine/debug/features/disable_fp8_gemm.py
View file @
0d874a4e
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
...
...
transformer_engine/debug/features/disable_fp8_layer.py
View file @
0d874a4e
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
...
...
transformer_engine/debug/features/fake_quant.py
View file @
0d874a4e
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
...
...
Prev
1
…
18
19
20
21
22
23
24
25
26
…
32
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment