Unverified Commit 643fdd2b authored by PanZezhong1725's avatar PanZezhong1725 Committed by GitHub
Browse files

Merge pull request #27 from PanZezhong1725/issue/1_optimize_matmul_test

issue1: New Optimized Matmul Test
parents f0af9e13 e17f5662
......@@ -8,3 +8,17 @@ class InfiniDeviceEnum:
ILUVATAR = 6
KUNLUN = 7
SUGON = 8
# Mapping that maps InfiniDeviceEnum to torch device string
infiniDeviceEnum_str_map = {
InfiniDeviceEnum.CPU: "cpu",
InfiniDeviceEnum.NVIDIA: "cuda",
InfiniDeviceEnum.CAMBRICON: "mlu",
InfiniDeviceEnum.ASCEND: "npu",
InfiniDeviceEnum.METAX: "cuda",
InfiniDeviceEnum.MOORE: "musa",
InfiniDeviceEnum.ILUVATAR: "cuda",
InfiniDeviceEnum.KUNLUN: "cuda",
InfiniDeviceEnum.SUGON: "cuda",
}
import ctypes
from .datatypes import *
from .devices import *
from typing import Sequence
from .liboperators import infiniopTensorDescriptor_t, CTensor, infiniopHandle_t
......@@ -46,12 +48,15 @@ def to_tensor(tensor, lib):
# Create Tensor
return CTensor(tensor_desc, data_ptr)
def create_workspace(size, torch_device):
print(f" - Workspace Size : {size}")
if size == 0:
return None
import torch
return torch.zeros(size=(size,), dtype=torch.uint8, device=torch_device)
def create_handle(lib, device, id=0):
handle = infiniopHandle_t()
check_error(lib.infiniopCreateHandle(ctypes.byref(handle), device, id))
......@@ -106,3 +111,276 @@ def rearrange_tensor(tensor, new_strides):
new_tensor.set_(new_tensor.untyped_storage(), offset, shape, tuple(new_strides))
return new_tensor
def rearrange_if_needed(tensor, stride):
"""
Rearrange a PyTorch tensor if the given stride is not None.
"""
return rearrange_tensor(tensor, stride) if stride is not None else tensor
def get_args():
import argparse
parser = argparse.ArgumentParser(description="Test Operator")
parser.add_argument(
"--profile",
action="store_true",
help="Whether profile tests",
)
parser.add_argument(
"--num_prerun",
type=lambda x: max(0, int(x)),
default=10,
help="Set the number of pre-runs before profiling. Default is 10. Must be a non-negative integer.",
)
parser.add_argument(
"--num_iterations",
type=lambda x: max(0, int(x)),
default=1000,
help="Set the number of iterations for profiling. Default is 1000. Must be a non-negative integer.",
)
parser.add_argument(
"--debug",
action="store_true",
help="Whether to turn on debug mode. If turned on, it will display detailed information about the tensors and discrepancies.",
)
parser.add_argument(
"--cpu",
action="store_true",
help="Run CPU test",
)
parser.add_argument(
"--nvidia",
action="store_true",
help="Run NVIDIA GPU test",
)
parser.add_argument(
"--cambricon",
action="store_true",
help="Run Cambricon MLU test",
)
parser.add_argument(
"--ascend",
action="store_true",
help="Run ASCEND NPU test",
)
return parser.parse_args()
def synchronize_device(torch_device):
import torch
if torch_device == "cuda":
torch.cuda.synchronize()
elif torch_device == "npu":
torch.npu.synchronize()
elif torch_device == "mlu":
torch.mlu.synchronize()
def debug(actual, desired, atol=0, rtol=1e-2, equal_nan=False, verbose=True):
"""
Debugging function to compare two tensors (actual and desired) and print discrepancies.
Arguments:
----------
- actual : The tensor containing the actual computed values.
- desired : The tensor containing the expected values that `actual` should be compared to.
- atol : optional (default=0)
The absolute tolerance for the comparison.
- rtol : optional (default=1e-2)
The relative tolerance for the comparison.
- equal_nan : bool, optional (default=False)
If True, `NaN` values in `actual` and `desired` will be considered equal.
- verbose : bool, optional (default=True)
If True, the function will print detailed information about any discrepancies between the tensors.
"""
import numpy as np
print_discrepancy(actual, desired, atol, rtol, verbose)
np.testing.assert_allclose(actual.cpu(), desired.cpu(), rtol, atol, equal_nan, verbose=True, strict=True)
def debug_all(actual_vals: Sequence, desired_vals: Sequence, condition: str, atol=0, rtol=1e-2, equal_nan=False, verbose=True):
"""
Debugging function to compare two sequences of values (actual and desired) pair by pair, results
are linked by the given logical condition, and prints discrepancies
Arguments:
----------
- actual_vals (Sequence): A sequence (e.g., list or tuple) of actual computed values.
- desired_vals (Sequence): A sequence (e.g., list or tuple) of desired (expected) values to compare against.
- condition (str): A string specifying the condition for passing the test. It must be either:
- 'or': Test passes if any pair of actual and desired values satisfies the tolerance criteria.
- 'and': Test passes if all pairs of actual and desired values satisfy the tolerance criteria.
- atol (float, optional): Absolute tolerance. Default is 0.
- rtol (float, optional): Relative tolerance. Default is 1e-2.
- equal_nan (bool, optional): If True, NaN values in both actual and desired are considered equal. Default is False.
- verbose (bool, optional): If True, detailed output is printed for each comparison. Default is True.
Raises:
----------
- AssertionError: If the condition is not satisfied based on the provided `condition`, `atol`, and `rtol`.
- ValueError: If the length of `actual_vals` and `desired_vals` do not match.
- AssertionError: If the specified `condition` is not 'or' or 'and'.
"""
assert len(actual_vals) == len(desired_vals), "Invalid Length"
assert condition in {"or", "and"}, "Invalid condition: should be either 'or' or 'and'"
import numpy as np
passed = False if condition == "or" else True
for index, (actual, desired) in enumerate(zip(actual_vals, desired_vals)):
print(f" \033[36mCondition #{index + 1}:\033[0m {actual} == {desired}")
indices = print_discrepancy(actual, desired, atol, rtol, verbose)
if condition == "or":
if not passed and len(indices) == 0:
passed = True
elif condition == "and":
if passed and len(indices) != 0:
passed = False
print(f"\033[31mThe condition has not been satisfied: Condition #{index + 1}\033[0m")
np.testing.assert_allclose(actual.cpu(), desired.cpu(), rtol, atol, equal_nan, verbose=True, strict=True)
assert passed, "\033[31mThe condition has not been satisfied\033[0m"
def print_discrepancy(
actual, expected, atol=0, rtol=1e-3, verbose=True
):
if actual.shape != expected.shape:
raise ValueError("Tensors must have the same shape to compare.")
import torch
import sys
is_terminal = sys.stdout.isatty()
# Calculate the difference mask based on atol and rtol
diff_mask = torch.abs(actual - expected) > (atol + rtol * torch.abs(expected))
diff_indices = torch.nonzero(diff_mask, as_tuple=False)
delta = actual - expected
# Display format: widths for columns
col_width = [18, 20, 20, 20]
decimal_places = [0, 12, 12, 12]
total_width = sum(col_width) + sum(decimal_places)
def add_color(text, color_code):
if is_terminal:
return f"\033[{color_code}m{text}\033[0m"
else:
return text
if verbose:
for idx in diff_indices:
index_tuple = tuple(idx.tolist())
actual_str = f"{actual[index_tuple]:<{col_width[1]}.{decimal_places[1]}f}"
expected_str = f"{expected[index_tuple]:<{col_width[2]}.{decimal_places[2]}f}"
delta_str = f"{delta[index_tuple]:<{col_width[3]}.{decimal_places[3]}f}"
print(
f" > Index: {str(index_tuple):<{col_width[0]}}"
f"actual: {add_color(actual_str, 31)}"
f"expect: {add_color(expected_str, 32)}"
f"delta: {add_color(delta_str, 33)}"
)
print(add_color(" INFO:", 35))
print(f" - Actual dtype: {actual.dtype}")
print(f" - Desired dtype: {expected.dtype}")
print(f" - Atol: {atol}")
print(f" - Rtol: {rtol}")
print(f" - Mismatched elements: {len(diff_indices)} / {actual.numel()} ({len(diff_indices) / actual.numel() * 100}%)")
print(f" - Min(actual) : {torch.min(actual):<{col_width[1]}} | Max(actual) : {torch.max(actual):<{col_width[2]}}")
print(f" - Min(desired): {torch.min(expected):<{col_width[1]}} | Max(desired): {torch.max(expected):<{col_width[2]}}")
print(f" - Min(delta) : {torch.min(delta):<{col_width[1]}} | Max(delta) : {torch.max(delta):<{col_width[2]}}")
print("-" * total_width + "\n")
return diff_indices
def get_tolerance(tolerance_map, tensor_dtype, default_atol=0, default_rtol=1e-3):
"""
Returns the atol and rtol for a given tensor data type in the tolerance_map.
If the given data type is not found, it returns the provided default tolerance values.
"""
return tolerance_map.get(tensor_dtype, {'atol': default_atol, 'rtol': default_rtol}).values()
def timed_op(func, num_iterations, device):
import time
""" Function for timing operations with synchronization. """
synchronize_device(device)
start = time.time()
for _ in range(num_iterations):
func()
synchronize_device(device)
return (time.time() - start) / num_iterations
def profile_operation(desc, func, torch_device, NUM_PRERUN, NUM_ITERATIONS):
"""
Unified profiling workflow that is used to profile the execution time of a given function.
It first performs a number of warmup runs, then performs timed execution and
prints the average execution time.
Arguments:
----------
- desc (str): Description of the operation, used for output display.
- func (callable): The operation function to be profiled.
- torch_device (str): The device on which the operation runs, provided for timed execution.
- NUM_PRERUN (int): The number of warmup runs.
- NUM_ITERATIONS (int): The number of timed execution iterations, used to calculate the average execution time.
"""
# Warmup runs
for _ in range(NUM_PRERUN):
func()
# Timed execution
elapsed = timed_op(lambda: func(), NUM_ITERATIONS, torch_device)
print(f" {desc} time: {elapsed * 1000 :6f} ms")
def test_operator(lib, device, test_func, test_cases, tensor_dtypes):
"""
Testing a specified operator on the given device with the given test function, test cases, and tensor data types.
Arguments:
----------
- lib (ctypes.CDLL): The library object containing the operator implementations.
- device (InfiniDeviceEnum): The device on which the operator should be tested. See device.py.
- test_func (function): The test function to be executed for each test case.
- test_cases (list of tuples): A list of test cases, where each test case is a tuple of parameters
to be passed to `test_func`.
- tensor_dtypes (list): A list of tensor data types (e.g., `torch.float32`) to test.
"""
handle = create_handle(lib, device)
try:
for test_case in test_cases:
for tensor_dtype in tensor_dtypes:
test_func(lib, handle, infiniDeviceEnum_str_map[device], *test_case, tensor_dtype)
finally:
destroy_handle(lib, handle)
def get_test_devices(args):
"""
Using the given parsed Namespace to determine the devices to be tested.
Argument:
- args: the parsed Namespace object.
Return:
- devices_to_test: the devices that will be tested. Default is CPU.
"""
devices_to_test = []
if args.cpu: devices_to_test.append(InfiniDeviceEnum.CPU)
if args.nvidia: devices_to_test.append(InfiniDeviceEnum.NVIDIA)
if args.cambricon:
import torch_mlu
devices_to_test.append(InfiniDeviceEnum.CAMBRICON)
if args.ascend:
import torch_npu
devices_to_test.append(InfiniDeviceEnum.ASCEND)
if not devices_to_test:
devices_to_test = [InfiniDeviceEnum.CPU]
return devices_to_test
from ctypes import POINTER, Structure, c_int32, c_size_t, c_uint64, c_void_p, c_float
import torch
import ctypes
import sys
import os
import time
sys.path.append("..")
from ctypes import POINTER, Structure, c_int32, c_size_t, c_uint64, c_void_p, c_float
from libinfiniop import (
open_lib,
to_tensor,
CTensor,
InfiniDeviceEnum,
infiniopHandle_t,
infiniopTensorDescriptor_t,
create_handle,
destroy_handle,
check_error,
rearrange_tensor,
create_workspace,
infiniopHandle_t, infiniopTensorDescriptor_t, open_lib, to_tensor, get_test_devices,
check_error, rearrange_if_needed, create_workspace, test_operator, get_args,
debug, get_tolerance, profile_operation,
)
from test_utils import get_args, synchronize_device
import torch
# ==============================================================================
# Configuration (Internal Use Only)
# ==============================================================================
# These are not meant to be imported from other modules
_TEST_CASES = [
# alpha, beta, a_shape, b_shape, c_shape, a_stride, b_stride, c_stride
(1.0, 0.0, (1, 2048), (2048, 2048), (1, 2048), None, None, None),
(1.0, 0.0, (1, 2048), (2048, 2048), (1, 2048), None, None, None),
(1.0, 0.0, (2, 4, 2048), (2, 2048, 2048), (2, 4, 2048), None, None, None),
(1.0, 0.0, (2, 4, 2048), (2, 2048, 2048), (2, 4, 2048), None, None, None),
(1.0, 0.0, (1, 2048), (2048, 2048), (1, 2048), (4096, 1), (4096, 1), (4096, 1)),
(1.0, 0.0, (1, 2048), (2048, 2048), (1, 2048), (4096, 1), (4096, 1), (4096, 1)),
(1.0, 1.0, (6, 2048), (2048, 2560), (6, 2560), (2048, 1), (1, 2048), (2560, 1)),
(1.0, 1.0, (6, 2048), (2048, 2560), (6, 2560), (2048, 1), (1, 2048), (2560, 1)),
(1.0/8.0, 0.0, (4, 8*6, 64), (4, 64, 6), (4, 8*6, 6), None, None, None),
(1.0/8.0, 0.0, (4, 8*6, 64), (4, 64, 6), (4, 8*6, 6), None, None, None),
]
# Data types used for testing
_TENSOR_DTYPES = [torch.float16, torch.float32]
# Tolerance map for different data types
_TOLERANCE_MAP = {
torch.float16: {'atol': 0, 'rtol': 1e-2},
torch.float32: {'atol': 0, 'rtol': 1e-3},
}
DEBUG = False
PROFILE = False
NUM_PRERUN = 10
NUM_ITERATIONS = 1000
# ==============================================================================
# Definitions
# ==============================================================================
class MatmulDescriptor(Structure):
_fields_ = [("device", c_int32)]
infiniopMatmulDescriptor_t = POINTER(MatmulDescriptor)
# PyTorch implementation for matrix multiplication
def matmul(_c, beta, _a, _b, alpha):
a = _a.clone()
b = _b.clone()
c = _c.clone()
input_dtype = c.dtype
ans = (
alpha * torch.matmul(a.to(torch.float32), b.to(torch.float32)).to(input_dtype)
+ beta * c
)
return ans
a, b, c = _a.clone(), _b.clone(), _c.clone()
result_dtype = c.dtype
fp32_result = torch.matmul(a.to(torch.float32), b.to(torch.float32))
return alpha * fp32_result.to(result_dtype) + beta * c
# The argument list should be (lib, handle, torch_device, <param list>, dtype)
# The <param list> should keep the same order as the one specified in _TEST_CASES
def test(
lib,
handle,
......@@ -60,26 +72,22 @@ def test(
dtype=torch.float16,
):
print(
f"Testing Matmul on {torch_device} with a_shape:{a_shape} b_shape:{b_shape} c_shape:{c_shape}"
f" a_stride:{a_stride} b_stride:{b_stride} c_stride:{c_stride} dtype:{dtype}"
f"Testing Matmul on {torch_device} with alpha:{alpha}, beta:{beta},"
f" a_shape:{a_shape}, b_shape:{b_shape}, c_shape:{c_shape},"
f" a_stride:{a_stride}, b_stride:{b_stride}, c_stride:{c_stride}, dtype:{dtype}"
)
# Initialize tensors
a = torch.rand(a_shape, dtype=dtype).to(torch_device)
b = torch.rand(b_shape, dtype=dtype).to(torch_device)
c = torch.ones(c_shape, dtype=dtype).to(torch_device)
# Compute the PyTorch reference result
ans = matmul(c, beta, a, b, alpha)
if a_stride is not None:
a = rearrange_tensor(a, a_stride)
if b_stride is not None:
b = rearrange_tensor(b, b_stride)
if c_stride is not None:
c = rearrange_tensor(c, c_stride)
a, b, c = [rearrange_if_needed(tensor, stride) for tensor, stride in zip([a, b, c], [a_stride, b_stride, c_stride])]
a_tensor, b_tensor, c_tensor = [to_tensor(tensor, lib) for tensor in [a, b, c]]
a_tensor = to_tensor(a, lib)
b_tensor = to_tensor(b, lib)
c_tensor = to_tensor(c, lib)
descriptor = infiniopMatmulDescriptor_t()
check_error(
lib.infiniopCreateMatmulDescriptor(
......@@ -92,20 +100,19 @@ def test(
)
# Invalidate the shape and strides in the descriptor to prevent them from being directly used by the kernel
a_tensor.descriptor.contents.invalidate()
b_tensor.descriptor.contents.invalidate()
c_tensor.descriptor.contents.invalidate()
for tensor in [a_tensor, b_tensor, c_tensor]:
tensor.descriptor.contents.invalidate()
# Get workspace size and create workspace
workspace_size = c_uint64(0)
check_error(
lib.infiniopGetMatmulWorkspaceSize(descriptor, ctypes.byref(workspace_size))
)
check_error(lib.infiniopGetMatmulWorkspaceSize(descriptor, ctypes.byref(workspace_size)))
workspace = create_workspace(workspace_size.value, a.device)
check_error(
lib.infiniopMatmul(
descriptor,
workspace.data_ptr() if workspace is not None else None,
# Execute infiniop matmul operator
def lib_matmul():
check_error(lib.infiniopMatmul(
descriptor,
workspace.data_ptr() if workspace else None,
workspace_size.value,
c_tensor.data,
a_tensor.data,
......@@ -113,201 +120,27 @@ def test(
alpha,
beta,
None,
)
)
))
lib_matmul()
assert torch.allclose(c, ans, atol=0, rtol=1e-2)
# Validate results
atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype)
if DEBUG:
debug(c, ans, atol=atol, rtol=rtol)
assert torch.allclose(c, ans, atol=atol, rtol=rtol)
# Profiling workflow
if PROFILE:
for i in range(NUM_PRERUN):
_ = matmul(c, beta, a, b, alpha)
synchronize_device(torch_device)
start_time = time.time()
for i in range(NUM_ITERATIONS):
_ = matmul(c, beta, a, b, alpha)
synchronize_device(torch_device)
elapsed = (time.time() - start_time) / NUM_ITERATIONS
print(f" pytorch time: {elapsed * 1000 :6f} ms")
for i in range(NUM_PRERUN):
check_error(
lib.infiniopMatmul(
descriptor,
workspace.data_ptr() if workspace is not None else None,
workspace_size.value,
c_tensor.data,
a_tensor.data,
b_tensor.data,
None,
)
)
synchronize_device(torch_device)
start_time = time.time()
for i in range(NUM_ITERATIONS):
check_error(
lib.infiniopMatmul(
descriptor,
workspace.data_ptr() if workspace is not None else None,
workspace_size.value,
c_tensor.data,
a_tensor.data,
b_tensor.data,
None,
)
)
synchronize_device(torch_device)
elapsed = (time.time() - start_time) / NUM_ITERATIONS
print(f" lib time: {elapsed * 1000 :6f} ms")
profile_operation("PyTorch", lambda: matmul(c, beta, a, b, alpha), torch_device, NUM_PRERUN, NUM_ITERATIONS)
profile_operation(" lib", lambda: lib_matmul(), torch_device, NUM_PRERUN, NUM_ITERATIONS)
check_error(lib.infiniopDestroyMatmulDescriptor(descriptor))
def test_cpu(lib, test_cases):
device = InfiniDeviceEnum.CPU
handle = create_handle(lib, device)
for (
alpha,
beta,
a_shape,
b_shape,
c_shape,
a_stride,
b_stride,
c_stride,
dtype,
) in test_cases:
test(
lib,
handle,
"cpu",
alpha,
beta,
a_shape,
b_shape,
c_shape,
a_stride,
b_stride,
c_stride,
dtype,
)
destroy_handle(lib, handle)
def test_nvidia(lib, test_cases):
device = InfiniDeviceEnum.NVIDIA
handle = create_handle(lib, device)
for (
alpha,
beta,
a_shape,
b_shape,
c_shape,
a_stride,
b_stride,
c_stride,
dtype,
) in test_cases:
test(
lib,
handle,
"cuda",
alpha,
beta,
a_shape,
b_shape,
c_shape,
a_stride,
b_stride,
c_stride,
dtype,
)
destroy_handle(lib, handle)
def test_cambricon(lib, test_cases):
import torch_mlu
device = InfiniDeviceEnum.CAMBRICON
handle = create_handle(lib, device)
for (
alpha,
beta,
a_shape,
b_shape,
c_shape,
a_stride,
b_stride,
c_stride,
dtype,
) in test_cases:
test(
lib,
handle,
"mlu",
alpha,
beta,
a_shape,
b_shape,
c_shape,
a_stride,
b_stride,
c_stride,
dtype,
)
destroy_handle(lib, handle)
def test_ascend(lib, test_cases):
import torch_npu
device = InfiniDeviceEnum.ASCEND
handle = create_handle(lib, device)
for (
alpha,
beta,
a_shape,
b_shape,
c_shape,
a_stride,
b_stride,
c_stride,
dtype,
) in test_cases:
test(
lib,
handle,
"npu",
alpha,
beta,
a_shape,
b_shape,
c_shape,
a_stride,
b_stride,
c_stride,
dtype,
)
destroy_handle(lib, handle)
# ==============================================================================
# Main Execution
# ==============================================================================
if __name__ == "__main__":
test_cases = [
# alpha, beta, a_shape, b_shape, c_shape, a_stride, b_stride, c_stride, dtype
(1.0, 0.0, (1, 2048), (2048, 2048), (1, 2048), None, None, None, torch.float16),
(1.0, 0.0, (1, 2048), (2048, 2048), (1, 2048), None, None, None, torch.float32),
(1.0, 0.0, (2, 4, 2048), (2, 2048, 2048), (2, 4, 2048), None, None, None, torch.float16),
(1.0, 0.0, (2, 4, 2048), (2, 2048, 2048), (2, 4, 2048), None, None, None, torch.float32),
(1.0, 0.0, (1, 2048), (2048, 2048), (1, 2048), (4096, 1), (4096, 1), (4096, 1), torch.float16),
(1.0, 0.0, (1, 2048), (2048, 2048), (1, 2048), (4096, 1), (4096, 1), (4096, 1), torch.float32),
(1.0, 1.0, (6, 2048), (2048, 2560), (6, 2560), (2048, 1), (1, 2048), (2560, 1), torch.float16),
(1.0, 1.0, (6, 2048), (2048, 2560), (6, 2560), (2048, 1), (1, 2048), (2560, 1), torch.float32),
(1.0 / 8.0, 0.0, (4, 8 * 6, 64), (4, 64, 6), (4, 8 * 6, 6), None, None, None, torch.float16),
(1.0 / 8.0, 0.0, (4, 8 * 6, 64), (4, 64, 6), (4, 8 * 6, 6), None, None, None, torch.float32),
]
args = get_args()
lib = open_lib()
......@@ -344,16 +177,14 @@ if __name__ == "__main__":
infiniopMatmulDescriptor_t,
]
if args.profile:
PROFILE = True
if args.cpu:
test_cpu(lib, test_cases)
if args.nvidia:
test_nvidia(lib, test_cases)
if args.cambricon:
test_cambricon(lib, test_cases)
if args.ascend:
test_ascend(lib, test_cases)
if not (args.cpu or args.nvidia or args.cambricon or args.ascend):
test_cpu(lib, test_cases)
# Configure testing options
DEBUG = args.debug
PROFILE = args.profile
NUM_PRERUN = args.num_prerun
NUM_ITERATIONS = args.num_iterations
# Execute tests
for device in get_test_devices(args):
test_operator(lib, device, test, _TEST_CASES, _TENSOR_DTYPES)
print("\033[92mTest passed!\033[0m")
def get_args():
import argparse
parser = argparse.ArgumentParser(description="Test Operator")
parser.add_argument(
"--profile",
action="store_true",
help="Whether profile tests",
)
parser.add_argument(
"--cpu",
action="store_true",
help="Run CPU test",
)
parser.add_argument(
"--nvidia",
action="store_true",
help="Run NVIDIA GPU test",
)
parser.add_argument(
"--cambricon",
action="store_true",
help="Run Cambricon MLU test",
)
parser.add_argument(
"--ascend",
action="store_true",
help="Run ASCEND NPU test",
)
return parser.parse_args()
def synchronize_device(torch_device):
import torch
if torch_device == "cuda":
torch.cuda.synchronize()
elif torch_device == "npu":
torch.npu.synchronize()
elif torch_device == "mlu":
torch.mlu.synchronize()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment