OpenDAS / TransformerEngine / Commits

Commit 0d874a4e, authored Mar 03, 2026 by wenjh

    Merge branch 'nv_main' of v2.12

Parents: a68e5f87, dfdd3820
Changes: 640 files; this page shows 20 changed files with 55 additions and 33 deletions (+55 / -33)
Changed files shown on this page:

transformer_engine/common/multi_tensor/scale.cu                                  +1 / -1
transformer_engine/common/multi_tensor/sgd.cu                                    +1 / -1
transformer_engine/common/normalization/common.cpp                               +16 / -4
transformer_engine/common/normalization/common.h                                 +1 / -1
transformer_engine/common/normalization/kernel_traits.h                          +1 / -1
transformer_engine/common/normalization/layernorm/ln_api.cpp                     +10 / -5
transformer_engine/common/normalization/layernorm/ln_bwd_kernels.cuh             +1 / -1
transformer_engine/common/normalization/layernorm/ln_bwd_semi_cuda_kernel.cu     +1 / -1
transformer_engine/common/normalization/layernorm/ln_fwd_cuda_kernel.cu          +1 / -1
transformer_engine/common/normalization/layernorm/ln_fwd_kernels.cuh             +1 / -1
transformer_engine/common/normalization/rmsnorm/rmsnorm_api.cpp                  +12 / -7
transformer_engine/common/normalization/rmsnorm/rmsnorm_bwd_kernels.cuh          +1 / -1
transformer_engine/common/normalization/rmsnorm/rmsnorm_bwd_semi_cuda_kernel.cu  +1 / -1
transformer_engine/common/normalization/rmsnorm/rmsnorm_fwd_cuda_kernel.cu       +1 / -1
transformer_engine/common/normalization/rmsnorm/rmsnorm_fwd_kernels.cuh          +1 / -1
transformer_engine/common/nvshmem_api/CMakeLists.txt                             +1 / -1
transformer_engine/common/nvshmem_api/nvshmem_waitkernel.cu                      +1 / -1
transformer_engine/common/nvshmem_api/nvshmem_waitkernel.h                       +1 / -1
transformer_engine/common/nvtx.h                                                 +1 / -1
transformer_engine/common/permutation/permutation.cu                             +1 / -1

Note: the merge is too large to show in full; to preserve performance only 640 of 640+ files are displayed.
transformer_engine/common/multi_tensor/scale.cu (view file @ 0d874a4e)

 /*************************************************************************
- * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
  *
  * See LICENSE for license information.
  ************************************************************************/
 ...
transformer_engine/common/multi_tensor/sgd.cu (view file @ 0d874a4e)

 /*************************************************************************
- * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
  *
  * See LICENSE for license information.
  ************************************************************************/
 ...
transformer_engine/common/normalization/common.cpp (view file @ 0d874a4e)

 /*************************************************************************
- * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
  *
  * See LICENSE for license information.
  ************************************************************************/
 ...
@@ -129,7 +129,13 @@ void TeNormalizationPlan<KernelParamsType>::_build() {
 template <typename KernelParamsType>
 std::vector<size_t> TeNormalizationPlan<KernelParamsType>::getWorkspaceShape() const {
-  return {_launch_params.getTotalWorkspaceBytes(_is_layernorm)};
+  size_t workspace_size = _launch_params.getTotalWorkspaceBytes(_is_layernorm);
+  if (workspace_size == 0) {
+    // Workspace size must not be zero since that corresponds to a
+    // workspace size query
+    workspace_size = 1;
+  }
+  return {workspace_size};
 }

 template <typename KernelParamsType>
 ...
@@ -418,9 +424,15 @@ void CudnnNormalizationPlan::_build() {
 std::vector<size_t> CudnnNormalizationPlan::getWorkspaceShape() const {
 #ifdef USE_ROCM
   assert(false);
-  return {0};
+  return {1};
 #else
-  return {static_cast<size_t>(_graph.get_workspace_size())};
+  size_t workspace_size = _graph.get_workspace_size();
+  if (workspace_size == 0) {
+    // Workspace size must not be zero since that corresponds to a
+    // workspace size query
+    workspace_size = 1;
+  }
+  return {workspace_size};
 #endif
 }
 ...
transformer_engine/common/normalization/common.h (view file @ 0d874a4e)

 /*************************************************************************
- * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
  *
  * See LICENSE for license information.
  ************************************************************************/
 ...
transformer_engine/common/normalization/kernel_traits.h (view file @ 0d874a4e)

 /*************************************************************************
- * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
  *
  * See LICENSE for license information.
  ************************************************************************/
 ...
transformer_engine/common/normalization/layernorm/ln_api.cpp (view file @ 0d874a4e)

 /*************************************************************************
- * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
  *
  * See LICENSE for license information.
  ************************************************************************/
 ...
@@ -27,10 +27,15 @@ void layernorm_fwd(const Tensor& x,  // BxSxhidden_size
                    const float epsilon, Tensor* z, Tensor* mu, Tensor* rsigma, Tensor* workspace,
                    const int multiprocessorCount, const bool zero_centered_gamma,
                    cudaStream_t stream) {
+  // Check for unsupported configurations
   if (is_fp8_dtype(z->data.dtype) && !is_delayed_tensor_scaling(z->scaling_mode) &&
       !is_mxfp8_scaling(z->scaling_mode)) {
     NVTE_ERROR("Not implemented scaling mode: " + to_string(z->scaling_mode) + ".");
   }
+  if (is_mxfp8_scaling(z->scaling_mode)) {
+    NVTE_CHECK(!z->with_gemm_swizzled_scales,
+               "MXFP8 output must have scales in compact format, not swizzled for GEMM.");
+  }

   NVTE_CHECK(x.data.shape.size() == 2, "x must be 2D tensor.");
   NVTE_CHECK(gamma.data.shape == beta.data.shape, "Gamma and Beta must have the same shape.");
@@ -51,7 +56,7 @@ void layernorm_fwd(const Tensor& x,  // BxSxhidden_size
              "RSigma must be 1D tensor with shape (x.shape[0],).");
   NVTE_CHECK(rsigma->data.dtype == DType::kFloat32, "RSigma must be a float32 tensor.");

-  if (!workspace->data.shape.empty()) {
+  if (workspace->data.numel() != 0) {
     CheckInputTensor(x, "x");
     CheckInputTensor(gamma, "gamma");
     CheckInputTensor(beta, "beta");
@@ -101,7 +106,7 @@ void layernorm_fwd(const Tensor& x,  // BxSxhidden_size
       multiprocessorCount, zero_centered_gamma, is_aligned, z->scaling_mode, training,
       gamma_in_weight_dtype);

-  if (workspace->data.shape.empty()) {
+  if (workspace->data.numel() == 0) {
     workspace->data.shape = plan->getWorkspaceShape();
     workspace->data.dtype = DType::kByte;
     return;
@@ -153,7 +158,7 @@ void layernorm_bwd(const Tensor& dz, const Tensor& x, const Tensor& mu, const Te
   NVTE_CHECK(dbeta->data.shape == gamma.data.shape);
   NVTE_CHECK(dbeta->data.dtype == gamma.data.dtype);

-  if (!workspace->data.shape.empty()) {
+  if (workspace->data.numel() != 0) {
     CheckInputTensor(dz, "dz");
     CheckInputTensor(x, "x");
     CheckInputTensor(mu, "mu");
@@ -186,7 +191,7 @@ void layernorm_bwd(const Tensor& dz, const Tensor& x, const Tensor& mu, const Te
       multiprocessorCount, zero_centered_gamma, is_aligned, NVTE_DELAYED_TENSOR_SCALING, true,
       gamma_in_weight_dtype);

-  if (workspace->data.shape.empty()) {
+  if (workspace->data.numel() == 0) {
     workspace->data.shape = plan->getWorkspaceShape();
     workspace->data.dtype = DType::kByte;
     return;
 ...
transformer_engine/common/normalization/layernorm/ln_bwd_kernels.cuh (view file @ 0d874a4e)

 /*************************************************************************
- * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
  *
  * See LICENSE for license information.
  ************************************************************************/
 ...
transformer_engine/common/normalization/layernorm/ln_bwd_semi_cuda_kernel.cu (view file @ 0d874a4e)

 /*************************************************************************
- * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
  *
  * See LICENSE for license information.
  ************************************************************************/
 ...
transformer_engine/common/normalization/layernorm/ln_fwd_cuda_kernel.cu (view file @ 0d874a4e)

 /*************************************************************************
- * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
  *
  * See LICENSE for license information.
  ************************************************************************/
 ...
transformer_engine/common/normalization/layernorm/ln_fwd_kernels.cuh (view file @ 0d874a4e)

 /*************************************************************************
- * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
  *
  * See LICENSE for license information.
  ************************************************************************/
 ...
transformer_engine/common/normalization/rmsnorm/rmsnorm_api.cpp (view file @ 0d874a4e)

 /*************************************************************************
- * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
  *
  * See LICENSE for license information.
  ************************************************************************/
 ...
@@ -23,10 +23,15 @@ using namespace normalization;
 void rmsnorm_fwd(const Tensor &x, const Tensor &gamma, const float epsilon, Tensor *z,
                  Tensor *rsigma, Tensor *workspace, const int multiprocessorCount,
                  const bool zero_centered_gamma, cudaStream_t stream) {
+  // Check for unsupported configurations
   if (is_fp8_dtype(z->data.dtype) && !is_delayed_tensor_scaling(z->scaling_mode) &&
       !is_mxfp8_scaling(z->scaling_mode)) {
     NVTE_ERROR("Not implemented scaling mode: " + to_string(z->scaling_mode) + ".");
   }
+  if (is_mxfp8_scaling(z->scaling_mode)) {
+    NVTE_CHECK(!z->with_gemm_swizzled_scales,
+               "MXFP8 output must have scales in compact format, not swizzled for GEMM.");
+  }

   NVTE_CHECK(x.data.shape.size() == 2, "x must be 2D tensor.");
@@ -39,7 +44,7 @@ void rmsnorm_fwd(const Tensor &x, const Tensor &gamma, const float epsilon, Tens
              "RSigma must be 1D tensor with shape (x.shape[0],).");
   NVTE_CHECK(rsigma->data.dtype == DType::kFloat32, "RSigma must be a float32 tensor.");

-  if (!workspace->data.shape.empty()) {
+  if (workspace->data.numel() != 0) {
     CheckInputTensor(x, "x");
     CheckInputTensor(gamma, "gamma");
@@ -86,7 +91,7 @@ void rmsnorm_fwd(const Tensor &x, const Tensor &gamma, const float epsilon, Tens
       multiprocessorCount, zero_centered_gamma, is_aligned, z->scaling_mode, training,
       gamma_in_weight_dtype);

-  if (workspace->data.shape.empty()) {
+  if (workspace->data.numel() == 0) {
     workspace->data.shape = plan->getWorkspaceShape();
     workspace->data.dtype = DType::kByte;
     return;
@@ -132,7 +137,7 @@ void rmsnorm_bwd(const Tensor &dz, const Tensor &x, const Tensor &rsigma, const
   NVTE_CHECK(dgamma->data.shape == gamma.data.shape);
   NVTE_CHECK(dgamma->data.dtype == gamma.data.dtype);

-  if (!workspace->data.shape.empty()) {
+  if (workspace->data.numel() != 0) {
     CheckInputTensor(dz, "dz");
     CheckInputTensor(x, "x");
     CheckInputTensor(rsigma, "rsigma");
@@ -163,7 +168,7 @@ void rmsnorm_bwd(const Tensor &dz, const Tensor &x, const Tensor &rsigma, const
       multiprocessorCount, zero_centered_gamma, is_aligned, NVTE_DELAYED_TENSOR_SCALING, true,
       gamma_in_weight_dtype);

-  if (workspace->data.shape.empty()) {
+  if (workspace->data.numel() == 0) {
     workspace->data.shape = plan->getWorkspaceShape();
     workspace->data.dtype = DType::kByte;
     return;
@@ -198,7 +203,7 @@ void rmsnorm_bwd_add(const Tensor &dz, const Tensor &x, const Tensor &add, const
   NVTE_CHECK(dgamma->data.shape == gamma.data.shape);
   NVTE_CHECK(dgamma->data.dtype == gamma.data.dtype);

-  if (!workspace->data.shape.empty()) {
+  if (workspace->data.numel() != 0) {
     CheckInputTensor(dz, "dz");
     CheckInputTensor(x, "x");
     CheckInputTensor(add, "add");
@@ -229,7 +234,7 @@ void rmsnorm_bwd_add(const Tensor &dz, const Tensor &x, const Tensor &add, const
       multiprocessorCount, zero_centered_gamma, is_aligned, NVTE_DELAYED_TENSOR_SCALING, true,
       gamma_in_weight_dtype);

-  if (workspace->data.shape.empty()) {
+  if (workspace->data.numel() == 0) {
     workspace->data.shape = plan->getWorkspaceShape();
     workspace->data.dtype = DType::kByte;
     return;
 ...
transformer_engine/common/normalization/rmsnorm/rmsnorm_bwd_kernels.cuh (view file @ 0d874a4e)

 /*************************************************************************
- * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
  *
  * See LICENSE for license information.
  ************************************************************************/
 ...
transformer_engine/common/normalization/rmsnorm/rmsnorm_bwd_semi_cuda_kernel.cu (view file @ 0d874a4e)

 /*************************************************************************
- * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
  *
  * See LICENSE for license information.
  ************************************************************************/
 ...
transformer_engine/common/normalization/rmsnorm/rmsnorm_fwd_cuda_kernel.cu (view file @ 0d874a4e)

 /*************************************************************************
- * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
  *
  * See LICENSE for license information.
  ************************************************************************/
 ...
transformer_engine/common/normalization/rmsnorm/rmsnorm_fwd_kernels.cuh (view file @ 0d874a4e)

 /*************************************************************************
- * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
  *
  * See LICENSE for license information.
  ************************************************************************/
 ...
transformer_engine/common/nvshmem_api/CMakeLists.txt (view file @ 0d874a4e)

 ##########################################################################
-# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 #
 # See LICENSE for license information.
 ##########################################################################
 ...
transformer_engine/common/nvshmem_api/nvshmem_waitkernel.cu (view file @ 0d874a4e)

 /*************************************************************************
- * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
  *
  * See LICENSE for license information.
  ************************************************************************/
 ...
transformer_engine/common/nvshmem_api/nvshmem_waitkernel.h (view file @ 0d874a4e)

 /*************************************************************************
- * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
  *
  * See LICENSE for license information.
  ************************************************************************/
 ...
transformer_engine/common/nvtx.h (view file @ 0d874a4e)

 /*************************************************************************
- * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
  *
  * See LICENSE for license information.
  ************************************************************************/
 ...
transformer_engine/common/permutation/permutation.cu (view file @ 0d874a4e)

 /*************************************************************************
- * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
  *
  * See LICENSE for license information.
  ************************************************************************/
 ...