"docker/vscode:/vscode.git/clone" did not exist on "fbd3199c0209b7ace266d798a6a8e95813d6cde0"
Unverified Commit 240ea97b authored by Jeff Rasley, committed by GitHub

only add 1bit adam reqs if mpi is installed, update cond build for cpu-adam (#400)

parent b29229bf
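
The main behavioral change: the 1bit-adam extras (including the matching cupy-cuda* wheel) are now pulled in only when both CUDA and an MPI implementation are detected, rather than unconditionally whenever CUDA is available. A minimal standalone sketch of that gate, mirroring the logic added in the hunks below (the helper name mpi_is_installed is our invention for illustration; the script inlines the checks):

    import shutil
    import torch

    def mpi_is_installed() -> bool:
        # shutil.which returns the executable's full path, or None if absent.
        # ompi_info ships with Open MPI; mpiname ships with MPICH-family builds.
        return bool(shutil.which('ompi_info') or shutil.which('mpiname'))

    if torch.cuda.is_available() and mpi_is_installed():
        # e.g. torch.version.cuda == '10.1' -> 'cupy-cuda101'
        cupy_pkg = f"cupy-cuda{torch.version.cuda.replace('.','')[:3]}"
        print(f"would add 1bit-adam requirements plus {cupy_pkg}")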
@@ -10,8 +10,10 @@ The wheel will be located at: dist/*.whl
 import os
 import torch
+import shutil
 import subprocess
 import warnings
+import cpufeature
 from setuptools import setup, find_packages
 from torch.utils.cpp_extension import CUDAExtension, BuildExtension, CppExtension
@@ -27,47 +29,53 @@ install_requires = fetch_requirements('requirements/requirements.txt')
 dev_requires = fetch_requirements('requirements/requirements-dev.txt')
 sparse_attn_requires = fetch_requirements('requirements/requirements-sparse-attn.txt')
-onebit_adam_requires = fetch_requirements('requirements/requirements-1bit-adam.txt')
+# If MPI is available add 1bit-adam requirements
 if torch.cuda.is_available():
-    onebit_adam_requires.append(f"cupy-cuda{torch.version.cuda.replace('.','')[:3]}")
-install_requires += onebit_adam_requires
+    if shutil.which('ompi_info') or shutil.which('mpiname'):
+        onebit_adam_requires = fetch_requirements(
+            'requirements/requirements-1bit-adam.txt')
+        onebit_adam_requires.append(f"cupy-cuda{torch.version.cuda.replace('.','')[:3]}")
+        install_requires += onebit_adam_requires
 
 # Constants for each op
 LAMB = "lamb"
 TRANSFORMER = "transformer"
 SPARSE_ATTN = "sparse-attn"
-ADAM = "cpu-adam"
+CPU_ADAM = "cpu-adam"
 
 # Build environment variables for custom builds
 DS_BUILD_LAMB_MASK = 1
 DS_BUILD_TRANSFORMER_MASK = 10
 DS_BUILD_SPARSE_ATTN_MASK = 100
-DS_BUILD_ADAM_MASK = 1000
+DS_BUILD_CPU_ADAM_MASK = 1000
 DS_BUILD_AVX512_MASK = 10000
 
 # Allow for build_cuda to turn on or off all ops
-DS_BUILD_ALL_OPS = DS_BUILD_LAMB_MASK | DS_BUILD_TRANSFORMER_MASK | DS_BUILD_SPARSE_ATTN_MASK | DS_BUILD_ADAM_MASK | DS_BUILD_AVX512_MASK
+DS_BUILD_ALL_OPS = DS_BUILD_LAMB_MASK | DS_BUILD_TRANSFORMER_MASK | DS_BUILD_SPARSE_ATTN_MASK | DS_BUILD_CPU_ADAM_MASK | DS_BUILD_AVX512_MASK
 DS_BUILD_CUDA = int(os.environ.get('DS_BUILD_CUDA', 1)) * DS_BUILD_ALL_OPS
 
 # Set default of each op based on if build_cuda is set
 OP_DEFAULT = DS_BUILD_CUDA == DS_BUILD_ALL_OPS
-DS_BUILD_ADAM = int(os.environ.get('DS_BUILD_ADAM', OP_DEFAULT)) * DS_BUILD_ADAM_MASK
+DS_BUILD_CPU_ADAM = int(os.environ.get('DS_BUILD_CPU_ADAM',
+                                       OP_DEFAULT)) * DS_BUILD_CPU_ADAM_MASK
 DS_BUILD_LAMB = int(os.environ.get('DS_BUILD_LAMB', OP_DEFAULT)) * DS_BUILD_LAMB_MASK
 DS_BUILD_TRANSFORMER = int(os.environ.get('DS_BUILD_TRANSFORMER',
                                           OP_DEFAULT)) * DS_BUILD_TRANSFORMER_MASK
 DS_BUILD_SPARSE_ATTN = int(os.environ.get('DS_BUILD_SPARSE_ATTN',
-                                          0)) * DS_BUILD_SPARSE_ATTN_MASK
-DS_BUILD_AVX512 = int(os.environ.get('DS_BUILD_AVX512', 0)) * DS_BUILD_AVX512_MASK
+                                          OP_DEFAULT)) * DS_BUILD_SPARSE_ATTN_MASK
+DS_BUILD_AVX512 = int(os.environ.get(
+    'DS_BUILD_AVX512',
+    cpufeature.CPUFeature['AVX512f'])) * DS_BUILD_AVX512_MASK
 
 # Final effective mask is the bitwise OR of each op
 BUILD_MASK = (DS_BUILD_LAMB | DS_BUILD_TRANSFORMER | DS_BUILD_SPARSE_ATTN
-              | DS_BUILD_ADAM)
+              | DS_BUILD_CPU_ADAM)
 
-install_ops = dict.fromkeys([LAMB, TRANSFORMER, SPARSE_ATTN, ADAM], False)
+install_ops = dict.fromkeys([LAMB, TRANSFORMER, SPARSE_ATTN, CPU_ADAM], False)
 if BUILD_MASK & DS_BUILD_LAMB:
     install_ops[LAMB] = True
-if BUILD_MASK & DS_BUILD_ADAM:
-    install_ops[ADAM] = True
+if BUILD_MASK & DS_BUILD_CPU_ADAM:
+    install_ops[CPU_ADAM] = True
 if BUILD_MASK & DS_BUILD_TRANSFORMER:
     install_ops[TRANSFORMER] = True
 if BUILD_MASK & DS_BUILD_SPARSE_ATTN:
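
Beyond the DS_BUILD_ADAM to DS_BUILD_CPU_ADAM renames, the hunk above changes two defaults: DS_BUILD_SPARSE_ATTN now follows OP_DEFAULT instead of a hard-coded 0, and DS_BUILD_AVX512 defaults to what the host CPU reports. A worked sketch of how the env-var and mask arithmetic composes (logic mirrored from the hunk, parameterized on an env dict so it runs standalone; transformer and sparse-attn are omitted for brevity):

    # The masks are decimal constants (1, 10, 100, ...) used as per-op flags;
    # each op's presence check ANDs BUILD_MASK with that op's own mask.
    LAMB_MASK, CPU_ADAM_MASK = 1, 1000
    ALL_OPS = LAMB_MASK | 10 | 100 | CPU_ADAM_MASK | 10000

    def build_mask(env):
        ds_build_cuda = int(env.get('DS_BUILD_CUDA', 1)) * ALL_OPS
        op_default = ds_build_cuda == ALL_OPS  # True only when DS_BUILD_CUDA is on
        lamb = int(env.get('DS_BUILD_LAMB', op_default)) * LAMB_MASK
        cpu_adam = int(env.get('DS_BUILD_CPU_ADAM', op_default)) * CPU_ADAM_MASK
        return lamb | cpu_adam

    # Default build: ops default on.
    assert build_mask({}) == LAMB_MASK | CPU_ADAM_MASK
    # Opt-out build: disable everything, then re-enable only cpu-adam.
    assert build_mask({'DS_BUILD_CUDA': '0', 'DS_BUILD_CPU_ADAM': '1'}) == CPU_ADAM_MASK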
@@ -103,9 +111,7 @@ if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR > 4):
     version_ge_1_5 = ['-DVERSION_GE_1_5']
 version_dependent_macros = version_ge_1_1 + version_ge_1_3 + version_ge_1_5
 
-import cpufeature
 cpu_info = cpufeature.CPUFeature
 SIMD_WIDTH = ''
 if cpu_info['AVX512f'] and DS_BUILD_AVX512:
     SIMD_WIDTH = '-D__AVX512__'
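
With cpufeature now imported at the top of setup.py, the local import is dropped here; the same CPUFeature table both seeds the DS_BUILD_AVX512 default and picks the SIMD compile flag. A tiny illustration of the flag selection (assumes the cpufeature package is installed):

    import cpufeature

    # The CPU feature table drives both the DS_BUILD_AVX512 default (earlier
    # hunk) and the -D__AVX512__ compile flag for the cpu-adam op.
    simd_width = '-D__AVX512__' if cpufeature.CPUFeature['AVX512f'] else ''
    print(simd_width or '(AVX-512 flag not set)')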
@@ -133,7 +139,7 @@ if BUILD_MASK & DS_BUILD_LAMB:
     }))
 
 ## Adam ##
-if BUILD_MASK & DS_BUILD_ADAM:
+if BUILD_MASK & DS_BUILD_CPU_ADAM:
     ext_modules.append(
         CUDAExtension(name='deepspeed.ops.adam.cpu_adam_op',
                       sources=[
...