Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
TransformerEngine
Commits
4927d10e
Commit
4927d10e
authored
Aug 26, 2025
by
yuguo
Browse files
[DCU] fix
parent
2e870ed9
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
14 additions
and
6 deletions
+14
-6
transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp
...mer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp
+14
-6
No files found.
transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp
View file @
4927d10e
...
@@ -117,13 +117,17 @@ CommOverlapCore::CommOverlapCore(int myrank, int numranks, int mylocal, int numl
...
@@ -117,13 +117,17 @@ CommOverlapCore::CommOverlapCore(int myrank, int numranks, int mylocal, int numl
NVTE_CHECK
(
false
,
"comm_cu_nums must be 4,8,16,32"
);
NVTE_CHECK
(
false
,
"comm_cu_nums must be 4,8,16,32"
);
}
}
const
char
*
NVTE_UB_COMM_CU_NUMS
=
std
::
getenv
(
"NVTE_UB_COMM_CU_NUMS"
);
static
cudaStream_t
compute_streams
[
NVTE_COMM_OVERLAP_MAX_STREAMS
];
static
cudaStream_t
compute_streams
[
NVTE_COMM_OVERLAP_MAX_STREAMS
];
for
(
int
i
=
0
;
i
<
std
::
min
(
num_max_streams
,
num_splits
);
i
++
)
{
for
(
int
i
=
0
;
i
<
std
::
min
(
num_max_streams
,
num_splits
);
i
++
)
{
if
(
compute_streams
[
i
]
==
nullptr
)
{
if
(
compute_streams
[
i
]
==
nullptr
)
{
NVTE_
CHECK_CUDA
(
cudaStreamCreateWithPriority
(
&
compute_streams
[
i
],
cudaStreamNonBlocking
,
_gemm_priority
));
if
(
NVTE_
UB_COMM_CU_NUMS
!=
nullptr
&&
NVTE_UB_COMM_CU_NUMS
[
0
]
!=
'\0'
)
{
#ifdef __HIP_PLATFORM_AMD__
#ifdef __HIP_PLATFORM_AMD__
NVTE_CHECK_CUDA
(
hipExtStreamCreateWithCUMask
(
&
compute_streams
[
i
],
cuMaskSize
,
cuMask
));
NVTE_CHECK_CUDA
(
hipExtStreamCreateWithCUMask
(
&
compute_streams
[
i
],
cuMaskSize
,
cuMask
));
#endif
#endif
}
else
{
NVTE_CHECK_CUDA
(
cudaStreamCreateWithPriority
(
&
compute_streams
[
i
],
cudaStreamNonBlocking
,
_gemm_priority
));
}
}
}
_stream_compute
.
push_back
(
compute_streams
[
i
]);
_stream_compute
.
push_back
(
compute_streams
[
i
]);
}
}
...
@@ -359,14 +363,18 @@ CommOverlapBase::CommOverlapBase(const std::vector<size_t> &buffer_shape, DType
...
@@ -359,14 +363,18 @@ CommOverlapBase::CommOverlapBase(const std::vector<size_t> &buffer_shape, DType
NVTE_CHECK
(
false
,
"comm_cu_nums must be 4,8,16,32"
);
NVTE_CHECK
(
false
,
"comm_cu_nums must be 4,8,16,32"
);
}
}
const
char
*
NVTE_UB_COMM_CU_NUMS
=
std
::
getenv
(
"NVTE_UB_COMM_CU_NUMS"
);
static
cudaStream_t
comm_stream
;
static
cudaStream_t
comm_stream
;
if
(
comm_stream
==
nullptr
)
{
if
(
comm_stream
==
nullptr
)
{
NVTE_CHECK_CUDA
(
if
(
NVTE_UB_COMM_CU_NUMS
!=
nullptr
&&
NVTE_UB_COMM_CU_NUMS
[
0
]
!=
'\0'
)
{
cudaStreamCreateWithPriority
(
&
comm_stream
,
cudaStreamNonBlocking
,
_comm_priority
));
#ifdef __HIP_PLATFORM_AMD__
#ifdef __HIP_PLATFORM_AMD__
NVTE_CHECK_CUDA
(
NVTE_CHECK_CUDA
(
hipExtStreamCreateWithCUMask
(
&
comm_stream
,
cuMaskSize
,
cuMask
));
hipExtStreamCreateWithCUMask
(
&
comm_stream
,
cuMaskSize
,
cuMask
));
#endif
#endif
}
else
{
NVTE_CHECK_CUDA
(
cudaStreamCreateWithPriority
(
&
comm_stream
,
cudaStreamNonBlocking
,
_comm_priority
));
}
}
}
_stream_comm
=
comm_stream
;
_stream_comm
=
comm_stream
;
NVTE_CHECK_CUDA
(
cudaEventCreateWithFlags
(
&
_start_d2dcopy
,
0
));
NVTE_CHECK_CUDA
(
cudaEventCreateWithFlags
(
&
_start_d2dcopy
,
0
));
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment