Unverified Commit 240ea97b authored by Jeff Rasley, committed by GitHub

only add 1bit adam reqs if mpi is installed, update cond build for cpu-adam (#400)

parent b29229bf
@@ -10,8 +10,10 @@ The wheel will be located at: dist/*.whl
 import os
 import torch
+import shutil
 import subprocess
 import warnings
+import cpufeature
 from setuptools import setup, find_packages
 from torch.utils.cpp_extension import CUDAExtension, BuildExtension, CppExtension
@@ -27,47 +29,53 @@ install_requires = fetch_requirements('requirements/requirements.txt')
 dev_requires = fetch_requirements('requirements/requirements-dev.txt')
 sparse_attn_requires = fetch_requirements('requirements/requirements-sparse-attn.txt')
-onebit_adam_requires = fetch_requirements('requirements/requirements-1bit-adam.txt')
+# If MPI is available add 1bit-adam requirements
 if torch.cuda.is_available():
-    onebit_adam_requires.append(f"cupy-cuda{torch.version.cuda.replace('.','')[:3]}")
-    install_requires += onebit_adam_requires
+    if shutil.which('ompi_info') or shutil.which('mpiname'):
+        onebit_adam_requires = fetch_requirements(
+            'requirements/requirements-1bit-adam.txt')
+        onebit_adam_requires.append(f"cupy-cuda{torch.version.cuda.replace('.','')[:3]}")
+        install_requires += onebit_adam_requires

 # Constants for each op
 LAMB = "lamb"
 TRANSFORMER = "transformer"
 SPARSE_ATTN = "sparse-attn"
-ADAM = "cpu-adam"
+CPU_ADAM = "cpu-adam"

 # Build environment variables for custom builds
 DS_BUILD_LAMB_MASK = 1
 DS_BUILD_TRANSFORMER_MASK = 10
 DS_BUILD_SPARSE_ATTN_MASK = 100
-DS_BUILD_ADAM_MASK = 1000
+DS_BUILD_CPU_ADAM_MASK = 1000
 DS_BUILD_AVX512_MASK = 10000

 # Allow for build_cuda to turn on or off all ops
-DS_BUILD_ALL_OPS = DS_BUILD_LAMB_MASK | DS_BUILD_TRANSFORMER_MASK | DS_BUILD_SPARSE_ATTN_MASK | DS_BUILD_ADAM_MASK | DS_BUILD_AVX512_MASK
+DS_BUILD_ALL_OPS = DS_BUILD_LAMB_MASK | DS_BUILD_TRANSFORMER_MASK | DS_BUILD_SPARSE_ATTN_MASK | DS_BUILD_CPU_ADAM_MASK | DS_BUILD_AVX512_MASK
 DS_BUILD_CUDA = int(os.environ.get('DS_BUILD_CUDA', 1)) * DS_BUILD_ALL_OPS

 # Set default of each op based on if build_cuda is set
 OP_DEFAULT = DS_BUILD_CUDA == DS_BUILD_ALL_OPS
-DS_BUILD_ADAM = int(os.environ.get('DS_BUILD_ADAM', OP_DEFAULT)) * DS_BUILD_ADAM_MASK
+DS_BUILD_CPU_ADAM = int(os.environ.get('DS_BUILD_CPU_ADAM',
+                                       OP_DEFAULT)) * DS_BUILD_CPU_ADAM_MASK
 DS_BUILD_LAMB = int(os.environ.get('DS_BUILD_LAMB', OP_DEFAULT)) * DS_BUILD_LAMB_MASK
 DS_BUILD_TRANSFORMER = int(os.environ.get('DS_BUILD_TRANSFORMER',
                                           OP_DEFAULT)) * DS_BUILD_TRANSFORMER_MASK
 DS_BUILD_SPARSE_ATTN = int(os.environ.get('DS_BUILD_SPARSE_ATTN',
-                                          0)) * DS_BUILD_SPARSE_ATTN_MASK
-DS_BUILD_AVX512 = int(os.environ.get('DS_BUILD_AVX512', 0)) * DS_BUILD_AVX512_MASK
+                                          OP_DEFAULT)) * DS_BUILD_SPARSE_ATTN_MASK
+DS_BUILD_AVX512 = int(os.environ.get(
+    'DS_BUILD_AVX512',
+    cpufeature.CPUFeature['AVX512f'])) * DS_BUILD_AVX512_MASK

 # Final effective mask is the bitwise OR of each op
 BUILD_MASK = (DS_BUILD_LAMB | DS_BUILD_TRANSFORMER | DS_BUILD_SPARSE_ATTN
-              | DS_BUILD_ADAM)
+              | DS_BUILD_CPU_ADAM)

-install_ops = dict.fromkeys([LAMB, TRANSFORMER, SPARSE_ATTN, ADAM], False)
+install_ops = dict.fromkeys([LAMB, TRANSFORMER, SPARSE_ATTN, CPU_ADAM], False)
 if BUILD_MASK & DS_BUILD_LAMB:
     install_ops[LAMB] = True
-if BUILD_MASK & DS_BUILD_ADAM:
-    install_ops[ADAM] = True
+if BUILD_MASK & DS_BUILD_CPU_ADAM:
+    install_ops[CPU_ADAM] = True
 if BUILD_MASK & DS_BUILD_TRANSFORMER:
     install_ops[TRANSFORMER] = True
 if BUILD_MASK & DS_BUILD_SPARSE_ATTN:
@@ -103,9 +111,7 @@ if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR > 4):
     version_ge_1_5 = ['-DVERSION_GE_1_5']
 version_dependent_macros = version_ge_1_1 + version_ge_1_3 + version_ge_1_5

-import cpufeature
-
 cpu_info = cpufeature.CPUFeature
 SIMD_WIDTH = ''
 if cpu_info['AVX512f'] and DS_BUILD_AVX512:
     SIMD_WIDTH = '-D__AVX512__'
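A quick illustration of the new AVX512 default (standalone sketch, not part of the commit; the mask multiplication is dropped here): when DS_BUILD_AVX512 is unset, the fallback is the boolean from cpufeature's runtime CPU probe, so AVX-512-capable machines opt in automatically while the env var still overrides.

import os
import cpufeature

# Env values arrive as strings ('0'/'1'); the unset default is the bool
# cpufeature.CPUFeature['AVX512f'], and int() normalizes both (int(True) == 1).
DS_BUILD_AVX512 = int(os.environ.get('DS_BUILD_AVX512',
                                     cpufeature.CPUFeature['AVX512f']))
print('AVX512 build enabled:', bool(DS_BUILD_AVX512))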
@@ -133,7 +139,7 @@ if BUILD_MASK & DS_BUILD_LAMB:
         }))

 ## Adam ##
-if BUILD_MASK & DS_BUILD_ADAM:
+if BUILD_MASK & DS_BUILD_CPU_ADAM:
     ext_modules.append(
         CUDAExtension(name='deepspeed.ops.adam.cpu_adam_op',
                       sources=[
...
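To see how the renamed cpu-adam flag composes with DS_BUILD_CUDA, here is a minimal standalone sketch of the mask scheme (illustrative, mirroring the constants in the second hunk): DS_BUILD_CUDA=0 zeroes every op's default, but an op can still be re-enabled individually.

import os

DS_BUILD_CPU_ADAM_MASK = 1000
DS_BUILD_ALL_OPS = 1 | 10 | 100 | 1000 | 10000

DS_BUILD_CUDA = int(os.environ.get('DS_BUILD_CUDA', 1)) * DS_BUILD_ALL_OPS
OP_DEFAULT = DS_BUILD_CUDA == DS_BUILD_ALL_OPS  # True only when all ops default on
DS_BUILD_CPU_ADAM = int(os.environ.get('DS_BUILD_CPU_ADAM',
                                       OP_DEFAULT)) * DS_BUILD_CPU_ADAM_MASK

# With `DS_BUILD_CUDA=0 DS_BUILD_CPU_ADAM=1`, OP_DEFAULT is False but
# DS_BUILD_CPU_ADAM == 1000, so only the cpu_adam_op extension should build.
print('cpu-adam enabled:', bool(DS_BUILD_CPU_ADAM & DS_BUILD_CPU_ADAM_MASK))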