OpenDAS / TransformerEngine

Commit 456a96c8, authored Apr 18, 2025 by yuguo
Parent: b9ec4909

    [DCU] overlap bug fix in ECO and BW finally
Showing 3 changed files, with 8 additions and 51 deletions:

    transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.cu   +5 -44
    transformer_engine/pytorch/module/base.py                                +1 -5
    transformer_engine/pytorch/utils.py                                      +2 -2
transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.cu

@@ -80,7 +80,7 @@
   printf("[%s:%s:%d] " message "\n", FILENAME(__FILE__), __FUNCTION__, __LINE__, __VA_ARGS__)
 // Report and error on timeout
-#define CHECK_TIMEOUT(t, timeout) ((clock64() - (t)) > timeout)
+#define CHECK_TIMEOUT(t, timeout) ((static_cast<int64_t>(clock64()) - static_cast<int64_t>(t)) > static_cast<int64_t>(timeout))
 template <int RANKS>
 __global__ void __launch_bounds__(MAX_THREADS)
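The one-line change above hardens the timeout test against unsigned promotion. Assuming ub_timeout is carried as an unsigned 64-bit value (which the new casts suggest), the old form compares a signed clock difference against an unsigned timeout, so a wrapped or skewed counter reading turns a small negative difference into a huge unsigned one and the wait aborts immediately. A minimal host-side sketch of the failure mode:

#include <cstdint>
#include <cstdio>

int main() {
  unsigned long long ub_timeout = 1000000ULL;  // stand-in for the real timeout value
  long long start = 100;                       // stand-ins for two clock64() samples;
  long long now = 50;                          // wrap/skew can make now < start

  // Old macro shape: the signed difference (-50) is converted to unsigned for
  // the comparison, becomes 2^64 - 50, and the "timeout" fires instantly.
  bool old_fires = (now - start) > ub_timeout;

  // Fixed macro shape: everything compared as int64_t; -50 > 1000000 is false.
  bool new_fires = (static_cast<int64_t>(now) - static_cast<int64_t>(start)) >
                   static_cast<int64_t>(ub_timeout);

  printf("old fires: %d, fixed fires: %d\n", old_fires, new_fires);  // 1, 0
  return 0;
}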
@@ -292,17 +292,11 @@ __global__ void __launch_bounds__(MAX_THREADS)
     targetgpu = threadIdx.x * gpustep + firstrank;
     myptr = (reinterpret_cast<int *>(commbuff[physgpu])) + flagoffset;
     reduceidptr = myptr - NVTE_MAX_OPS;  // +op;
-    __threadfence_system();
     reduce_id = (*reduceidptr) + 1;
-    __threadfence_system();
     flagptr = (reinterpret_cast<int *>(commbuff[targetgpu])) + flagoffset;
-    __threadfence_system();
     if (blockIdx.x == 0) flagptr[physgpu] = reduce_id;
-    __threadfence_system();
     volatile int *flag = (volatile int *)&(myptr[targetgpu]);
-    __threadfence_system();
     userptr[threadIdx.x] = reinterpret_cast<int4 *>(commbuff[targetgpu + handleridx]);
-    __threadfence_system();
     clock_t s = clock64();
     while (CHECK_IDS(*flag, reduce_id)) {
       if (CHECK_TIMEOUT(s, ub_timeout)) {
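This hunk and the five like it below thin out the same handshake: each block derives its peer pointers, publishes the new reduce_id into the peer's flag word, and then spin-waits until the peer's counter catches up, guarded by the timeout macro fixed above. The __threadfence_system() calls removed here sat between plain local pointer computations, where they ordered nothing useful and cost a system-wide fence each. A minimal sketch of the spin-wait that remains (CHECK_IDS is assumed here to be a simple "has not caught up" test; its real definition is not part of this diff):

#include <cstdint>

// Assumed stand-ins for the kernel's macros: CHECK_TIMEOUT matches the fixed
// definition above; CHECK_IDS is my assumption, not taken from the diff.
#define CHECK_IDS(flag, id) ((flag) != (id))
#define CHECK_TIMEOUT(t, timeout) \
  ((static_cast<int64_t>(clock64()) - static_cast<int64_t>(t)) > static_cast<int64_t>(timeout))

// Spin until the peer publishes reduce_id into the volatile flag word, or the
// timeout expires. volatile forces a fresh read of the flag on each iteration.
__device__ bool wait_for_peer_flag(volatile int *flag, int reduce_id,
                                   unsigned long long ub_timeout) {
  clock_t s = clock64();                   // stamp the start of the wait
  while (CHECK_IDS(*flag, reduce_id)) {
    if (CHECK_TIMEOUT(s, ub_timeout))
      return false;                        // the real kernel reports an error here
  }
  return true;
}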
@@ -315,9 +309,7 @@ __global__ void __launch_bounds__(MAX_THREADS)
   __syncthreads();
   if (threadIdx.x == 0) {
     const int adder = blockIdx.x == 0 ? NVTE_MAX_SMS - gridDim.x + 1 : 1;
-    __threadfence_system();
     int old_val = atomicAdd(myptr + (NVTE_MAX_NVLINK * 2), adder);
-    __threadfence_system();
     if (old_val + adder == NVTE_MAX_SMS * reduce_id) lastSM = 1;
   }
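The adder arithmetic above is a last-block election. Every block adds 1 except block 0, which adds NVTE_MAX_SMS - gridDim.x + 1, so each operation raises the counter by exactly NVTE_MAX_SMS regardless of grid size: with gridDim.x = 4 and NVTE_MAX_SMS = 32, block 0 contributes 29 and the other three contribute 1 each. After the reduce_id-th operation the counter reads NVTE_MAX_SMS * reduce_id, and exactly one block, the last to arrive, observes that value and sets lastSM. A condensed sketch:

#ifndef NVTE_MAX_SMS
#define NVTE_MAX_SMS 32   // placeholder value; the real constant comes from the file
#endif

// Returns 1 in exactly one block per operation: the one whose atomicAdd brings
// the shared counter up to NVTE_MAX_SMS * reduce_id.
__device__ int elect_last_sm(int *counter, int reduce_id) {
  const int adder = blockIdx.x == 0 ? NVTE_MAX_SMS - gridDim.x + 1 : 1;
  const int old_val = atomicAdd(counter, adder);   // returns the pre-add value
  return old_val + adder == NVTE_MAX_SMS * reduce_id;
}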
@@ -348,9 +340,7 @@ __global__ void __launch_bounds__(MAX_THREADS)
       userptr[myrank][mylineoffset + line] = sum;
     }
   }
-  __threadfence_system();
   if (threadIdx.x == 0 && lastSM) *reduceidptr = reduce_id;
-  __threadfence_system();
 }
 // fp16 inplace reduce-scatter kernel
 template <int RANKS>
@@ -368,17 +358,11 @@ __global__ void __launch_bounds__(MAX_THREADS) userbuffers_fp16_sum_inplace_gpu_
     targetgpu = threadIdx.x * gpustep + firstrank;
     myptr = (reinterpret_cast<int *>(commbuff[physgpu])) + flagoffset;
     reduceidptr = myptr - NVTE_MAX_OPS;  // +op;
-    __threadfence_system();
     reduce_id = (*reduceidptr) + 1;
-    __threadfence_system();
     flagptr = (reinterpret_cast<int *>(commbuff[targetgpu])) + flagoffset;
-    __threadfence_system();
     if (blockIdx.x == 0) flagptr[physgpu] = reduce_id;
-    __threadfence_system();
     volatile int *flag = (volatile int *)&(myptr[targetgpu]);
-    __threadfence_system();
     userptr[threadIdx.x] = reinterpret_cast<int4 *>(commbuff[targetgpu + handleridx]);
-    __threadfence_system();
     clock_t s = clock64();
     while (CHECK_IDS(*flag, reduce_id)) {
       if (CHECK_TIMEOUT(s, ub_timeout)) {
@@ -391,9 +375,7 @@ __global__ void __launch_bounds__(MAX_THREADS) userbuffers_fp16_sum_inplace_gpu_
   __syncthreads();
   if (threadIdx.x == 0) {
     const int adder = blockIdx.x == 0 ? NVTE_MAX_SMS - gridDim.x + 1 : 1;
-    __threadfence_system();
     int old_val = atomicAdd(myptr + (NVTE_MAX_NVLINK * 2), adder);
-    __threadfence_system();
     if (old_val + adder == NVTE_MAX_SMS * reduce_id) lastSM = 1;
   }
@@ -424,9 +406,7 @@ __global__ void __launch_bounds__(MAX_THREADS) userbuffers_fp16_sum_inplace_gpu_
       (reinterpret_cast<int4 *>(outbuf))[(line / rowlines) * skiplines + (line % rowlines)] = sum;
     }
   }
-  __threadfence_system();
   if (threadIdx.x == 0 && lastSM) *reduceidptr = reduce_id;
-  __threadfence_system();
 }
 // fp16 reduce-scatter kernel (out of place)
 #if __CUDA_ARCH__ >= 900 && CUDART_VERSION >= 12010
@@ -1250,13 +1230,9 @@ __global__ void __launch_bounds__(MAX_THREADS)
     targetgpu = threadIdx.x * gpustep + firstrank;
     myptr = (reinterpret_cast<int *>(commbuff[physgpu])) + flagoffset;
     reduceidptr = myptr - NVTE_MAX_OPS;  // +op;
-    __threadfence_system();
     reduce_id = (*reduceidptr) + 1;
-    __threadfence_system();
     flagptr = (reinterpret_cast<int *>(commbuff[targetgpu])) + flagoffset;
-    __threadfence_system();
     userptr[threadIdx.x] = reinterpret_cast<int4 *>(commbuff[targetgpu + handleridx]);
-    __threadfence_system();
     clock_t s = clock64();
   }
@@ -1292,9 +1268,7 @@ __global__ void __launch_bounds__(MAX_THREADS)
   __shared__ int lastSM;
   if (threadIdx.x == 0) {
     const int adder = blockIdx.x == 0 ? NVTE_MAX_SMS - gridDim.x + 1 : 1;
-    __threadfence_system();
     int old_val = atomicAdd(myptr + (NVTE_MAX_NVLINK * 2), adder);
-    __threadfence_system();
     if (old_val + adder == NVTE_MAX_SMS * reduce_id)
       lastSM = 1;
     else
@@ -1302,13 +1276,9 @@ __global__ void __launch_bounds__(MAX_THREADS)
   }
   __syncthreads();
   if (lastSM && threadIdx.x < RANKS) {
-    __threadfence_system();
     if (threadIdx.x == 0) *reduceidptr = reduce_id;
-    __threadfence_system();
     flagptr[physgpu] = reduce_id;
-    __threadfence_system();
     volatile int *flag = (volatile int *)&myptr[targetgpu];
-    __threadfence_system();
     clock_t s = clock64();
     while (CHECK_IDS(*flag, reduce_id)) {
       if (CHECK_TIMEOUT(s, ub_timeout)) {
@@ -1337,13 +1307,9 @@ __global__ void __launch_bounds__(MAX_THREADS)
     targetgpu = threadIdx.x * gpustep + firstrank;
     myptr = (reinterpret_cast<int *>(commbuff[physgpu])) + flagoffset;
     reduceidptr = myptr - NVTE_MAX_OPS;  // +op;
-    __threadfence_system();
     reduce_id = (*reduceidptr) + 1;
-    __threadfence_system();
     flagptr = (reinterpret_cast<int *>(commbuff[targetgpu])) + flagoffset;
-    __threadfence_system();
     userptr[threadIdx.x] = reinterpret_cast<int4 *>(commbuff[targetgpu + handleridx]);
-    __threadfence_system();
   }
   __syncthreads();
   localptr = userptr[myrank];
@@ -1397,9 +1363,7 @@ __global__ void __launch_bounds__(MAX_THREADS)
   __shared__ int lastSM;
   if (threadIdx.x == 0) {
     const int adder = blockIdx.x == 0 ? NVTE_MAX_SMS - gridDim.x + 1 : 1;
-    __threadfence_system();
     int old_val = atomicAdd(myptr + (NVTE_MAX_NVLINK * 2), adder);
-    __threadfence_system();
     if (old_val + adder == NVTE_MAX_SMS * reduce_id)
       lastSM = 1;
     else
@@ -1407,13 +1371,9 @@ __global__ void __launch_bounds__(MAX_THREADS)
   }
   __syncthreads();
   if (lastSM && threadIdx.x < RANKS) {
-    __threadfence_system();
     if (threadIdx.x == 0) *reduceidptr = reduce_id;
-    __threadfence_system();
     flagptr[physgpu] = reduce_id;
-    __threadfence_system();
     volatile int *flag = (volatile int *)&myptr[targetgpu];
-    __threadfence_system();
     clock_t s = clock64();
     while (CHECK_IDS(*flag, reduce_id)) {
       if (CHECK_TIMEOUT(s, ub_timeout)) {
@@ -2197,8 +2157,11 @@ __global__ void __launch_bounds__(MAX_THREADS)
     atomicAdd_system(flagptr, 1);  // otherwise need local SM sync before sending flag
   } else {  // 0 bytes and 1 SM only
+#if defined(__gfx928__) || defined(__gfx926__) || defined(__gfx906__)
     *flagptr = *flagptr + 1;
     __threadfence_system();
+#else
+    atomicAdd_system(flagptr, 1);
+#endif
   }
 }
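These are the only added lines in the file. On the listed gfx9 DCU targets the flag publish stays a plain read-modify-write followed by a system fence; everywhere else it becomes the system-scope atomic. The plain store is only safe because, per the surrounding comment, this branch runs with a single SM and, by implication, a single writer for the flag, so the fence only has to order visibility, not arbitrate a race. Condensed, under that single-writer assumption:

// Sketch of the guarded publish (single-writer assumption noted above).
__device__ void publish_flag(int *flagptr) {
#if defined(__gfx928__) || defined(__gfx926__) || defined(__gfx906__)
  *flagptr = *flagptr + 1;        // plain RMW; no competing writer on these targets
  __threadfence_system();         // make the new value visible to peer GPUs and host
#else
  atomicAdd_system(flagptr, 1);   // system-scope atomic elsewhere
#endif
}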
@@ -2210,9 +2173,7 @@ __global__ void kuserbuffers_pushrecv(int myrank, int peer, int nvrank, int nvpe
                                       int *ce_start_ptr, int *ce_end_ptr) {
   const int signal_id = (*recv_id) + adder;
   *recv_id = signal_id;
-  __threadfence_system();
   volatile int *flag = (volatile int *)flagptr;
-  __threadfence_system();
   if (*flag >= signal_id) return;
   clock_t s = clock64();
   while (CHECK_IDS(*flag, signal_id)) {
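The receive side follows the mirror-image pattern: bump the locally tracked id first, take the fast path if the sender has already reached it, and otherwise spin under the same timeout. A minimal sketch, using the same assumed macros as above:

// Receive-side wait, condensed from kuserbuffers_pushrecv. signal_id is the
// next value the sender is expected to publish into *flagptr.
__device__ void wait_for_signal(int *recv_id, int adder, int *flagptr,
                                unsigned long long ub_timeout) {
  const int signal_id = (*recv_id) + adder;
  *recv_id = signal_id;                        // advance the local expectation
  volatile int *flag = (volatile int *)flagptr;
  if (*flag >= signal_id) return;              // sender already got there
  clock_t s = clock64();
  while (CHECK_IDS(*flag, signal_id)) {
    if (CHECK_TIMEOUT(s, ub_timeout)) break;   // the real kernel reports the hang
  }
}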
transformer_engine/pytorch/module/base.py

@@ -56,11 +56,7 @@ def get_cublas_workspace_size_bytes() -> None:
     """Return 32 MiB if using hopper, 4 MiB for all other architectures."""
     # Add env for control the padding for blaslt
     if IS_HIP_EXTENSION:
-        nvte_blaslt_nopad = int(os.environ.get("NVTE_BLASLT_NOPAD", 0))
-        if (nvte_blaslt_nopad):
-            return 536_870_912
-        else:
-            return 1_073_741_824
+        return 134_217_728
     if torch.cuda.get_device_properties(torch.cuda.current_device()).major >= 9:
         return 33_554_432
     return 4_194_304
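The HIP branch now returns a fixed 134_217_728 bytes (128 MiB) instead of picking 512 MiB or 1 GiB from the NVTE_BLASLT_NOPAD environment variable; the CUDA path keeps 32 MiB on compute capability 9.x and 4 MiB otherwise (the docstring still describes only the CUDA sizes). A condensed restatement, as a hypothetical C++ helper rather than the Python function the diff actually touches:

#include <cstddef>

// Hypothetical helper mirroring the new selection; the real logic is the
// Python function above, which queries torch for the device properties.
size_t cublas_workspace_size_bytes(bool is_hip_build, int cc_major) {
  if (is_hip_build) return 134'217'728;  // 128 MiB, fixed; env knob removed
  if (cc_major >= 9) return 33'554'432;  // 32 MiB on Hopper-class devices
  return 4'194'304;                      // 4 MiB for everything else
}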
transformer_engine/pytorch/utils.py

@@ -253,7 +253,7 @@ if IS_HIP_EXTENSION:
         import re
         return (re.search('K100_AI', torch.cuda.get_device_name(torch.cuda.current_device())) is not None)
-    def is_BW3000():
+    def is_BW():
         """check whether this machine is BW"""
         import re
         return (re.search('BW', torch.cuda.get_device_name(torch.cuda.current_device())) is not None)

@@ -264,7 +264,7 @@ def is_bf16_compatible() -> None:
     """
     if IS_HIP_EXTENSION:
         # only MI200 and MI300 machines support bf16
-        if get_device_compute_capability() == (9, 4) or is_mi200() or is_K100_AI() or is_BW3000():
+        if get_device_compute_capability() == (9, 4) or is_mi200() or is_K100_AI() or is_BW():
             return True
         else:
             return False