Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
TransformerEngine
Commits
0d874a4e
Commit
0d874a4e
authored
Mar 03, 2026
by
wenjh
Browse files
Merge branch 'nv_main' of v2.12
parents
a68e5f87
dfdd3820
Changes
640
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
132 additions
and
63 deletions
+132
-63
transformer_engine/common/transpose/rtc/transpose.cu
transformer_engine/common/transpose/rtc/transpose.cu
+1
-1
transformer_engine/common/transpose/swap_first_dims.cu
transformer_engine/common/transpose/swap_first_dims.cu
+1
-1
transformer_engine/common/transpose/transpose.cu
transformer_engine/common/transpose/transpose.cu
+9
-5
transformer_engine/common/transpose/transpose.h
transformer_engine/common/transpose/transpose.h
+20
-0
transformer_engine/common/transpose/transpose_fusion.cu
transformer_engine/common/transpose/transpose_fusion.cu
+1
-3
transformer_engine/common/triton/__init__.py
transformer_engine/common/triton/__init__.py
+1
-1
transformer_engine/common/triton/cross_entropy.py
transformer_engine/common/triton/cross_entropy.py
+13
-3
transformer_engine/common/triton/pad.py
transformer_engine/common/triton/pad.py
+1
-1
transformer_engine/common/triton/permutation.py
transformer_engine/common/triton/permutation.py
+74
-37
transformer_engine/common/util/cuda_driver.cpp
transformer_engine/common/util/cuda_driver.cpp
+1
-1
transformer_engine/common/util/cuda_driver.h
transformer_engine/common/util/cuda_driver.h
+1
-1
transformer_engine/common/util/cuda_nvml.cpp
transformer_engine/common/util/cuda_nvml.cpp
+1
-1
transformer_engine/common/util/cuda_nvml.h
transformer_engine/common/util/cuda_nvml.h
+1
-1
transformer_engine/common/util/cuda_runtime.cpp
transformer_engine/common/util/cuda_runtime.cpp
+1
-1
transformer_engine/common/util/cuda_runtime.h
transformer_engine/common/util/cuda_runtime.h
+1
-1
transformer_engine/common/util/curanddx.hpp
transformer_engine/common/util/curanddx.hpp
+1
-1
transformer_engine/common/util/handle_manager.h
transformer_engine/common/util/handle_manager.h
+1
-1
transformer_engine/common/util/logging.h
transformer_engine/common/util/logging.h
+1
-1
transformer_engine/common/util/math.h
transformer_engine/common/util/math.h
+1
-1
transformer_engine/common/util/multi_stream.cpp
transformer_engine/common/util/multi_stream.cpp
+1
-1
No files found.
Too many changes to show.
To preserve performance only
640 of 640+
files are displayed.
Plain diff
Email patch
transformer_engine/common/transpose/rtc/transpose.cu
View file @
0d874a4e
/*************************************************************************
* Copyright (c) 2022-202
5
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2022-202
6
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
...
...
transformer_engine/common/transpose/swap_first_dims.cu
View file @
0d874a4e
/*************************************************************************
* Copyright (c) 2022-202
5
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2022-202
6
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
...
...
transformer_engine/common/transpose/transpose.cu
View file @
0d874a4e
/*************************************************************************
* Copyright (c) 2022-202
5
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2022-202
6
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
...
...
@@ -14,8 +14,10 @@
#include "../util/rtc.h"
#include "../util/string.h"
#include "../utils.cuh"
#include "./transpose.h"
namespace
transformer_engine
{
namespace
detail
{
namespace
{
...
...
@@ -203,7 +205,8 @@ void transpose(const Tensor &input, const Tensor &noop, Tensor *output_, cudaStr
NVTE_CHECK
(
input
.
data
.
dptr
!=
nullptr
,
"Input is not allocated."
);
NVTE_CHECK
(
output
.
data
.
dptr
!=
nullptr
,
"Output is not allocated."
);
NVTE_CHECK
(
input
.
data
.
dtype
==
output
.
data
.
dtype
,
"Input and output type must match."
);
NVTE_CHECK
(
input
.
data
.
dtype
==
output
.
data
.
dtype
,
"Input (dtype="
,
to_string
(
input
.
data
.
dtype
),
") and output (dtype="
,
to_string
(
output
.
data
.
dtype
),
") do not match."
);
if
(
noop
.
data
.
dptr
!=
nullptr
)
{
NVTE_CHECK
(
noop
.
numel
()
==
1
,
"Expected 1 element, "
,
"but found "
,
noop
.
numel
(),
"."
);
...
...
@@ -289,19 +292,20 @@ void transpose(const Tensor &input, const Tensor &noop, Tensor *output_, cudaStr
});
// NOLINT(*)
}
}
// namespace detail
}
// namespace transformer_engine
void
nvte_transpose
(
const
NVTETensor
input
,
NVTETensor
output
,
cudaStream_t
stream
)
{
NVTE_API_CALL
(
nvte_transpose
);
using
namespace
transformer_engine
;
auto
noop
=
Tensor
();
transpose
(
*
convertNVTETensorCheck
(
input
),
noop
,
convertNVTETensor
(
output
),
stream
);
detail
::
transpose
(
*
convertNVTETensorCheck
(
input
),
noop
,
convertNVTETensor
(
output
),
stream
);
}
void
nvte_transpose_with_noop
(
const
NVTETensor
input
,
const
NVTETensor
noop
,
NVTETensor
output
,
cudaStream_t
stream
)
{
NVTE_API_CALL
(
nvte_transpose_with_noop
);
using
namespace
transformer_engine
;
transpose
(
*
convertNVTETensorCheck
(
input
),
*
convertNVTETensorCheck
(
noop
),
convertNVTETensor
(
output
),
stream
);
detail
::
transpose
(
*
convertNVTETensorCheck
(
input
),
*
convertNVTETensorCheck
(
noop
),
convertNVTETensor
(
output
),
stream
);
}
transformer_engine/common/transpose/transpose.h
0 → 100644
View file @
0d874a4e
/*************************************************************************
* Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#ifndef TRANSFORMER_ENGINE_COMMON_TRANSPOSE_TRANSPOSE_H_
#define TRANSFORMER_ENGINE_COMMON_TRANSPOSE_TRANSPOSE_H_
#include "../common.h"
namespace
transformer_engine
{
namespace
detail
{
void
transpose
(
const
Tensor
&
input
,
const
Tensor
&
noop
,
Tensor
*
output_
,
cudaStream_t
stream
);
}
// namespace detail
}
// namespace transformer_engine
#endif // TRANSFORMER_ENGINE_COMMON_TRANSPOSE_TRANSPOSE_H_
transformer_engine/common/transpose/transpose_fusion.cu
View file @
0d874a4e
/*************************************************************************
* Copyright (c) 2022-202
5
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2022-202
6
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
...
...
@@ -392,8 +392,6 @@ void populate_transpose_dbias_workspace_config(const Tensor &input, /*cast*/
workspace
->
data
.
dtype
);
const
size_t
required_size
=
get_buffer_size_bytes
(
num_rows_partial_dbias
,
row_length
,
DType
::
kFloat32
);
NVTE_CHECK
(
!
workspace
->
data
.
shape
.
empty
(),
"Invalid workspace dims (expected ("
,
num_rows_partial_dbias
,
","
,
row_length
,
"), found ())"
);
NVTE_CHECK
(
workspace_size
>=
required_size
,
"Invalid workspace (expected dims=("
,
num_rows_partial_dbias
,
","
,
row_length
,
"), dtype="
,
to_string
(
DType
::
kFloat32
),
"; found dims="
,
workspace
->
data
.
shape
,
...
...
transformer_engine/common/triton/__init__.py
View file @
0d874a4e
# Copyright (c) 2022-202
5
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-202
6
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
...
...
transformer_engine/common/triton/cross_entropy.py
View file @
0d874a4e
# Copyright (c) 2022-202
5
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-202
6
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
...
...
@@ -18,6 +18,8 @@ def online_softmax_kernel(
m_d_X_y_stride
,
rank
,
n_cols
,
ignore_idx
,
n_non_ignore
,
BLOCK_SIZE
:
tl
.
constexpr
,
):
"""
...
...
@@ -32,6 +34,8 @@ def online_softmax_kernel(
m_d_X_y_stride (int): The stride of the m/d/X_y tensor.
rank (int): The rank of this device in the TP group.
n_cols (int): The number of columns in the input tensor.
ignore_idx (int): The index to ignore for loss calculation.
n_non_ignore: The number of non-ignored elements in the batch.
BLOCK_SIZE (int): The block size for Triton operations.
"""
...
...
@@ -44,6 +48,9 @@ def online_softmax_kernel(
Y_ptr
+=
program_id
*
Y_stride
y
=
tl
.
load
(
Y_ptr
)
if
y
!=
ignore_idx
:
tl
.
atomic_add
(
n_non_ignore
,
1
)
vocab_start_idx
=
rank
*
n_cols
vocab_end_idx
=
(
rank
+
1
)
*
n_cols
if
y
>=
vocab_start_idx
:
...
...
@@ -89,6 +96,7 @@ def cross_entropy_kernel(
world_size
,
ignore_idx
,
n_cols
,
n_rows
,
n_non_ignore
,
reduce_loss
:
tl
.
constexpr
,
label_smoothing
:
tl
.
constexpr
,
...
...
@@ -110,12 +118,14 @@ def cross_entropy_kernel(
world_size (int): The size of world involved in this distributed loss calculation.
ignore_idx (int): Tokens to be ignored for loss and gradient calculation.
n_cols (int): The number of columns in the input tensor.
n_non_ignore (int): The number of non-ignored elements in the batch.
n_rows (int): The number of rows in the batch (B * SQ), used for buffer indexing.
n_non_ignore: The number of non-ignored elements in the batch.
label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing.
BLOCK_SIZE (int): The block size for Triton operations.
"""
program_id
=
tl
.
program_id
(
0
).
to
(
tl
.
int64
)
n_non_ignore
=
tl
.
load
(
n_non_ignore
)
# locate the start index
X_ptr
+=
program_id
*
X_stride
...
...
@@ -140,7 +150,7 @@ def cross_entropy_kernel(
ori_X_y
=
tl
.
load
(
m_d_X_y_ptr
+
(
2
*
m_d_X_y_stride
))
for
i
in
range
(
1
,
world_size
):
offset
=
i
*
3
*
n_
non_ignore
*
m_d_X_y_stride
offset
=
i
*
3
*
n_
rows
*
m_d_X_y_stride
access_ptr
=
m_d_X_y_ptr
+
offset
m_new
=
tl
.
load
(
access_ptr
)
d_new
=
tl
.
load
(
access_ptr
+
m_d_X_y_stride
)
...
...
transformer_engine/common/triton/pad.py
View file @
0d874a4e
# Copyright (c) 2022-202
5
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-202
6
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
...
...
transformer_engine/common/triton/permutation.py
View file @
0d874a4e
# Copyright (c) 2022-202
5
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-202
6
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
...
...
@@ -81,10 +81,8 @@ def _argsort(x, indices, n_dims: tl.constexpr):
@
triton
.
jit
def
_row_id_map_pass_1_kernel
(
# pointers
#
input
pointers
routing_map_ptr
,
row_id_map_ptr
,
workspace_ptr
,
# sizes
num_tokens
,
# strides
...
...
@@ -92,6 +90,9 @@ def _row_id_map_pass_1_kernel(
stride_routing_map_expert
,
stride_row_id_map_token
,
stride_row_id_map_expert
,
# output pointers
row_id_map_ptr
,
workspace_ptr
,
# metas
BLOCK_SIZE
:
tl
.
constexpr
,
):
...
...
@@ -155,12 +156,11 @@ def _row_id_map_pass_2_kernel(
def
_row_id_map_pass_3_kernel
(
# pointers
row_id_map_ptr
,
# sizes
num_experts
:
tl
.
constexpr
,
# strides
stride_row_id_map_token
,
stride_row_id_map_expert
,
# metas
num_experts
:
tl
.
constexpr
,
LOAD_SIZE
:
tl
.
constexpr
,
):
pid
=
tl
.
program_id
(
0
)
...
...
@@ -194,18 +194,22 @@ def _row_id_map_pass_3_kernel(
@
triton
.
jit
def
_permute_kernel
(
# pointers
#
input
pointers
input_ptr
,
output_ptr
,
row_id_map_ptr
,
probs_ptr
,
scale_ptr
,
permuted_probs_ptr
,
permuted_scale_ptr
,
pad_offsets_ptr
,
# Pre-allocated output buffers for JAX input_output_aliases.
# These are aliased to output_ptr/permuted_probs_ptr in JAX, so they point to the same memory.
# In PyTorch, pass the same tensors as output_ptr/permuted_probs_ptr.
output_buf_ptr
,
# pylint: disable=unused-argument
permuted_probs_buf_ptr
,
# pylint: disable=unused-argument
# sizes
num_experts
:
tl
.
constexpr
,
hidden_size
:
tl
.
constexpr
,
scale_hidden_dim
,
num_tokens
,
# pylint: disable=unused-argument
num_out_tokens
,
# pylint: disable=unused-argument
# strides
stride_row_id_map_token
,
stride_row_id_map_expert
,
...
...
@@ -220,15 +224,28 @@ def _permute_kernel(
stride_permuted_probs_token
,
stride_permuted_scale_token
,
stride_permuted_scale_hidden
,
# output pointers
output_ptr
,
permuted_probs_ptr
,
# metas
num_experts
:
tl
.
constexpr
,
hidden_size
:
tl
.
constexpr
,
PERMUTE_PROBS
:
tl
.
constexpr
,
PERMUTE_SCALE
:
tl
.
constexpr
,
FUSION_PAD
:
tl
.
constexpr
,
BLOCK_SIZE
:
tl
.
constexpr
,
):
# Note: When FUSION_PAD=True, output buffers should be pre-zeroed by the caller
# to ensure padding positions contain zeros.
# PyTorch: Use torch.zeros() for output buffer allocation
# JAX: Pre-zeroed buffers should be passed (when input_output_aliases works)
expert_idx
=
0
pid_t
=
tl
.
program_id
(
0
)
pid_h
=
tl
.
program_id
(
1
)
cur_off
=
pid_h
*
BLOCK_SIZE
+
tl
.
arange
(
0
,
BLOCK_SIZE
)
mask
=
cur_off
<
hidden_size
src_row
=
pid_t
.
to
(
tl
.
int64
)
input_off
=
src_row
*
stride_input_token
+
cur_off
*
stride_input_hidden
inp
=
tl
.
load
(
input_ptr
+
input_off
,
mask
=
mask
)
...
...
@@ -245,6 +262,15 @@ def _permute_kernel(
dst_row
=
tl
.
load
(
row_id_map_ptr
+
pid_t
*
stride_row_id_map_token
+
idx
*
stride_row_id_map_expert
).
to
(
tl
.
int64
)
if
FUSION_PAD
or
PERMUTE_PROBS
:
expert_idx
=
tl
.
load
(
row_id_map_ptr
+
pid_t
*
stride_row_id_map_token
+
(
num_experts
+
idx
)
*
stride_row_id_map_expert
)
if
FUSION_PAD
:
pad_off
=
tl
.
load
(
pad_offsets_ptr
+
expert_idx
)
dst_row
=
dst_row
+
pad_off
output_off
=
dst_row
*
stride_output_token
+
cur_off
*
stride_output_hidden
if
PERMUTE_SCALE
:
permuted_scale_off
=
(
...
...
@@ -252,11 +278,6 @@ def _permute_kernel(
)
tl
.
store
(
permuted_scale_ptr
+
permuted_scale_off
,
scale
,
mask
=
mask_scale
)
if
PERMUTE_PROBS
:
expert_idx
=
tl
.
load
(
row_id_map_ptr
+
pid_t
*
stride_row_id_map_token
+
(
num_experts
+
idx
)
*
stride_row_id_map_expert
)
prob_off
=
pid_t
*
stride_probs_token
+
expert_idx
*
stride_probs_expert
prob
=
tl
.
load
(
probs_ptr
+
prob_off
)
if
pid_h
==
0
:
...
...
@@ -291,16 +312,16 @@ except RuntimeError:
@
triton
.
jit
def
_unpermute_kernel
(
# pointers
#
input
pointers
input_ptr
,
output_ptr
,
row_id_map_ptr
,
merging_probs_ptr
,
permuted_probs_ptr
,
unpermuted_probs_ptr
,
# sizes
num_experts
:
tl
.
constexpr
,
hidden_size
:
tl
.
constexpr
,
pad_offsets_ptr
,
# Dummy parameters for JAX input_output_aliases compatibility (matches _permute_kernel signature pattern)
# These are unused in the unpermute kernel but maintain consistency with the permute kernel.
output_buf_ptr
,
# pylint: disable=unused-argument
unpermuted_probs_buf_ptr
,
# pylint: disable=unused-argument
# strides
stride_row_id_map_token
,
stride_row_id_map_expert
,
...
...
@@ -313,14 +334,21 @@ def _unpermute_kernel(
stride_permuted_probs_token
,
stride_unpermuted_probs_token
,
stride_unpermuted_probs_expert
,
# output pointers
output_ptr
,
unpermuted_probs_ptr
,
# metas
num_experts
:
tl
.
constexpr
,
hidden_size
:
tl
.
constexpr
,
PROBS_LOAD_WIDTH
:
tl
.
constexpr
,
WITH_MERGING_PROBS
:
tl
.
constexpr
,
PERMUTE_PROBS
:
tl
.
constexpr
,
FUSION_UNPAD
:
tl
.
constexpr
,
BLOCK_SIZE
:
tl
.
constexpr
,
):
data_type
=
input_ptr
.
dtype
.
element_ty
compute_type
=
tl
.
float32
expert_idx
=
0
pid_t
=
tl
.
program_id
(
0
)
pid_h
=
tl
.
program_id
(
1
)
...
...
@@ -347,15 +375,19 @@ def _unpermute_kernel(
src_row
=
tl
.
load
(
row_id_map_ptr
+
pid_t
*
stride_row_id_map_token
+
idx
*
stride_row_id_map_expert
).
to
(
tl
.
int64
)
input_off
=
src_row
*
stride_input_token
+
current_offset
*
stride_input_hidden
inp
=
tl
.
load
(
input_ptr
+
input_off
,
mask
=
mask
)
inp
=
inp
.
to
(
compute_type
)
if
WITH_MERGING_PROBS
:
if
FUSION_UNPAD
or
WITH_MERGING_PROBS
:
expert_idx
=
tl
.
load
(
row_id_map_ptr
+
pid_t
*
stride_row_id_map_token
+
(
num_experts
+
idx
)
*
stride_row_id_map_expert
)
if
FUSION_UNPAD
:
pad_off
=
tl
.
load
(
pad_offsets_ptr
+
expert_idx
)
src_row
=
src_row
+
pad_off
input_off
=
src_row
*
stride_input_token
+
current_offset
*
stride_input_hidden
inp
=
tl
.
load
(
input_ptr
+
input_off
,
mask
=
mask
)
inp
=
inp
.
to
(
compute_type
)
if
WITH_MERGING_PROBS
:
merging_prob_off
=
(
pid_t
*
stride_merging_probs_token
+
expert_idx
*
stride_merging_probs_expert
)
...
...
@@ -401,16 +433,12 @@ except RuntimeError:
@
triton
.
jit
def
_unpermute_bwd_with_merging_probs_kernel
(
# pointers
#
input
pointers
fwd_output_grad_ptr
,
fwd_input_grad_ptr
,
fwd_input_ptr
,
merging_probs_ptr
,
merging_probs_grad_ptr
,
row_id_map_ptr
,
# sizes
num_experts
:
tl
.
constexpr
,
hidden_size
:
tl
.
constexpr
,
pad_offsets_ptr
,
# strides
stride_row_id_map_token
,
stride_row_id_map_expert
,
...
...
@@ -424,8 +452,14 @@ def _unpermute_bwd_with_merging_probs_kernel(
stride_merging_probs_expert
,
stride_merging_probs_grad_token
,
stride_merging_probs_grad_expert
,
# output pointers
fwd_input_grad_ptr
,
merging_probs_grad_ptr
,
# metas
num_experts
:
tl
.
constexpr
,
hidden_size
:
tl
.
constexpr
,
PROBS_LOAD_WIDTH
:
tl
.
constexpr
,
FUSION_UNPAD
:
tl
.
constexpr
,
BLOCK_SIZE
:
tl
.
constexpr
,
):
data_type
=
fwd_output_grad_ptr
.
dtype
.
element_ty
...
...
@@ -449,6 +483,9 @@ def _unpermute_bwd_with_merging_probs_kernel(
+
pid
*
stride_row_id_map_token
+
(
num_experts
+
idx
)
*
stride_row_id_map_expert
)
if
FUSION_UNPAD
:
pad_off
=
tl
.
load
(
pad_offsets_ptr
+
expert_idx
)
dst_row
=
dst_row
+
pad_off
prob_grad_accum
=
tl
.
zeros
((
BLOCK_SIZE
,),
dtype
=
compute_type
)
current_start
=
0
while
current_start
<
hidden_size
:
...
...
@@ -546,14 +583,10 @@ def _make_chunk_sort_map_kernel(
@
triton
.
jit
def
_sort_chunks_by_map_kernel
(
# pointers
#
input
pointers
input_ptr
,
output_ptr
,
row_id_map_ptr
,
probs_ptr
,
permuted_probs_ptr
,
# sizes
hidden_size
:
tl
.
constexpr
,
# strides
stride_input_token
,
stride_input_hidden
,
...
...
@@ -561,7 +594,11 @@ def _sort_chunks_by_map_kernel(
stride_output_hidden
,
stride_probs_token
,
stride_permuted_probs_token
,
# output pointers
output_ptr
,
permuted_probs_ptr
,
# metas
hidden_size
:
tl
.
constexpr
,
PERMUTE_PROBS
:
tl
.
constexpr
,
BLOCK_SIZE
:
tl
.
constexpr
,
FORWARD
:
tl
.
constexpr
,
...
...
transformer_engine/common/util/cuda_driver.cpp
View file @
0d874a4e
/*************************************************************************
* Copyright (c) 2022-202
5
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2022-202
6
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
...
...
transformer_engine/common/util/cuda_driver.h
View file @
0d874a4e
/*************************************************************************
* Copyright (c) 2022-202
5
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2022-202
6
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
...
...
transformer_engine/common/util/cuda_nvml.cpp
View file @
0d874a4e
/*************************************************************************
* Copyright (c) 2022-202
5
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2022-202
6
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
...
...
transformer_engine/common/util/cuda_nvml.h
View file @
0d874a4e
/*************************************************************************
* Copyright (c) 2022-202
5
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2022-202
6
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
...
...
transformer_engine/common/util/cuda_runtime.cpp
View file @
0d874a4e
/*************************************************************************
* Copyright (c) 2022-202
5
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2022-202
6
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
...
...
transformer_engine/common/util/cuda_runtime.h
View file @
0d874a4e
/*************************************************************************
* Copyright (c) 2022-202
5
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2022-202
6
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
...
...
transformer_engine/common/util/curanddx.hpp
View file @
0d874a4e
/*************************************************************************
* Copyright (c) 2022-202
5
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2022-202
6
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
...
...
transformer_engine/common/util/handle_manager.h
View file @
0d874a4e
/*************************************************************************
* Copyright (c) 2022-202
5
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2022-202
6
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
...
...
transformer_engine/common/util/logging.h
View file @
0d874a4e
/*************************************************************************
* Copyright (c) 2022-202
5
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2022-202
6
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
...
...
transformer_engine/common/util/math.h
View file @
0d874a4e
/*************************************************************************
* Copyright (c) 2022-202
5
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2022-202
6
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
...
...
transformer_engine/common/util/multi_stream.cpp
View file @
0d874a4e
/*************************************************************************
* Copyright (c) 2022-202
5
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2022-202
6
, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
...
...
Prev
1
…
17
18
19
20
21
22
23
24
25
…
32
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment