Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
TransformerEngine
Commits
57deee08
Commit
57deee08
authored
Mar 31, 2025
by
yuguo
Browse files
[DCU] cpp test compile pass
parent
ab122dac
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
18 additions
and
2 deletions
+18
-2
tests/cpp/util/CMakeLists.txt
tests/cpp/util/CMakeLists.txt
+11
-0
transformer_engine/pytorch/csrc/type_shim.h
transformer_engine/pytorch/csrc/type_shim.h
+1
-2
transformer_engine/pytorch/module/batched_linear.py
transformer_engine/pytorch/module/batched_linear.py
+6
-0
No files found.
tests/cpp/util/CMakeLists.txt
View file @
57deee08
...
...
@@ -2,6 +2,7 @@
#
# See LICENSE for license information.
if
(
USE_CUDA
)
add_executable
(
test_util
test_nvrtc.cpp
test_string.cpp
...
...
@@ -10,6 +11,16 @@ add_executable(test_util
find_package
(
OpenMP REQUIRED
)
target_link_libraries
(
test_util PUBLIC CUDA::cudart GTest::gtest_main
${
TE_LIB
}
CUDA::nvrtc CUDNN::cudnn OpenMP::OpenMP_CXX
)
else
()
add_executable
(
test_util
test_nvrtc_hip.cpp
test_string.cpp
../test_common.hip
)
find_package
(
OpenMP REQUIRED
)
target_link_libraries
(
test_util PUBLIC hip::host hip::device GTest::gtest_main
${
TE_LIB
}
OpenMP::OpenMP_CXX
)
endif
()
target_compile_options
(
test_util PRIVATE -O2 -fopenmp
)
include
(
GoogleTest
)
...
...
transformer_engine/pytorch/csrc/type_shim.h
View file @
57deee08
...
...
@@ -6,6 +6,7 @@
#pragma once
#include <ATen/ATen.h>
#include "common/utils.cuh"
// Forward/backward compatiblity hack around
// https://github.com/pytorch/pytorch/commit/3aeb78079bcd68282fe9117088e138b77318e288
...
...
@@ -267,8 +268,6 @@
AT_ERROR(#NAME, " not implemented for '", toString(TYPEIN), "'"); \
}
constexpr
uint32_t
THREADS_PER_WARP
=
32
;
template
<
typename
T
>
__device__
__forceinline__
T
reduce_block_into_lanes
(
T
*
x
,
T
val
,
int
lanes
=
1
,
...
...
transformer_engine/pytorch/module/batched_linear.py
View file @
57deee08
...
...
@@ -91,6 +91,12 @@ class _BatchedLinear(torch.autograd.Function):
# TODO Support Float8 Current Scaling # pylint: disable=fixme
if
fp8
and
FP8GlobalStateManager
.
get_fp8_recipe
().
float8_current_scaling
():
raise
NotImplementedError
(
"BatchedLinear does not yet support Float8 Current Scaling"
)
# TODO Support Float8 Delayed Scaling # pylint: disable=fixme
if
fp8
and
FP8GlobalStateManager
.
get_fp8_recipe
().
delayed
():
raise
NotImplementedError
(
"BatchedLinear does not yet support Float8 Delayed Scaling"
)
# TODO Support Float8 Per Tensor Scaling # pylint: disable=fixme
if
fp8
and
FP8GlobalStateManager
.
get_fp8_recipe
().
float8_per_tensor_scaling
():
raise
NotImplementedError
(
"BatchedLinear does not yet support Float8 Per Tensor Scaling"
)
# Make sure input dimensions are compatible
in_features
=
weights
[
0
].
shape
[
-
1
]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment