"qa/vscode:/vscode.git/clone" did not exist on "2f8739f56a09b70b7c986e57231fd290e8cf0fe0"
Commit aeceeac0 authored by wenjh's avatar wenjh
Browse files

Merge branch 'develop_v2.10' into release_v2.10

parents bd05b0dc 284d3f6f
......@@ -392,10 +392,10 @@ if "NVTE_PROJECT_BUILDING" not in os.environ or bool(int(os.getenv("NVTE_RELEASE
_CURAND_LIB_CTYPES = _load_curand()
_CUBLAS_LIB_CTYPES = _load_nvidia_cuda_library("cublas")
_CUDART_LIB_CTYPES = _load_nvidia_cuda_library("cuda_runtime")
_TE_LIB_CTYPES = _load_core_library()
# Needed to find the correct headers for NVRTC kernels.
if not os.getenv("NVTE_CUDA_INCLUDE_DIR") and _nvidia_cudart_include_dir():
os.environ["NVTE_CUDA_INCLUDE_DIR"] = _nvidia_cudart_include_dir()
except OSError:
pass
_TE_LIB_CTYPES = _load_core_library()
......@@ -605,7 +605,7 @@ class BatchedLinear(TransformerEngineBaseModule):
weight_tensors_fp8 = [None] * int(self.num_gemms)
from ..cpu_offload import CPUOffloadEnabled
from ..cpu_offload_v1 import CPUOffloadEnabled
if torch.is_grad_enabled():
linear_fn = _BatchLinear.apply
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment