Unverified Commit 7dbeba30 authored by Jeff Rasley, committed by GitHub

Add deepspeed env file to pass custom env values (#117)

* add support for deepspeed env file to pass custom env values

* simplify deepspeed config example
parent 7e3509bb
@@ -286,27 +286,39 @@ doc](https://microsoft.github.io/DeepSpeed/docs/htmlfiles/api/full/index.html).
 {
   "train_batch_size": 8,
   "gradient_accumulation_steps": 1,
-  "steps_per_print": 1,
-  "zero_optimization": true,
-  "disable_allgather": true,
   "optimizer": {
     "type": "Adam",
     "params": {
-      "lr": 0.00015,
-      "max_grad_norm": 1.0
+      "lr": 0.00015
     }
   },
   "fp16": {
-    "enabled": true,
-    "loss_scale": 0,
-    "loss_scale_window": 1000,
-    "hysteresis": 2,
-    "min_loss_scale": 1
-  }
+    "enabled": true
+  },
+  "zero_optimization": true
 }
 ```
+
+## Multi-Node Environment Variables
+When training across multiple nodes we have found it useful to support
+propagating user-defined environment variables. By default DeepSpeed will
+propagate all NCCL and PYTHON related environment variables that are set. If
+you would like to propagate additional variables you can specify them in a
+dot-file named `.deepspeed_env` that contains a new-line separated list of
+`VAR=VAL` entries. The DeepSpeed launcher will look in the local path you are
+executing from and also in your home directory (`~/`).
+
+As a concrete example, some clusters require special NCCL variables to be set
+prior to training. The user can simply add these variables to a
+`.deepspeed_env` file in their home directory that looks like this:
+```
+NCCL_IB_DISABLE=1
+NCCL_SOCKET_IFNAME=eth0
+```
+DeepSpeed will then make sure that these environment variables are set when
+launching each process on every node across the training job.
+
 # Launching DeepSpeed Training
 DeepSpeed installs the entry point `deepspeed` to launch distributed training.
 We illustrate an example usage of DeepSpeed with the following assumptions:
...
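Since the `.deepspeed_env` file documented above is just newline-separated `VAR=VAL` text, it can also be generated programmatically. Below is a minimal sketch of doing so; the `write_deepspeed_env` helper is hypothetical and not part of DeepSpeed, it simply writes the example NCCL settings to the home-directory location the launcher checks.

```python
import os

# Hypothetical helper (not part of DeepSpeed): write VAR=VAL lines to
# ~/.deepspeed_env, one of the locations the launcher searches.
def write_deepspeed_env(env_vars, path=None):
    path = path or os.path.join(os.path.expanduser("~"), ".deepspeed_env")
    with open(path, "w") as fd:
        for name, value in env_vars.items():
            fd.write("{}={}\n".format(name, value))
    return path

if __name__ == "__main__":
    # Reproduces the example file from the README section above.
    write_deepspeed_env({
        "NCCL_IB_DISABLE": "1",
        "NCCL_SOCKET_IFNAME": "eth0",
    })
```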
@@ -17,6 +17,8 @@ from deepspeed.pt.deepspeed_constants import TORCH_DISTRIBUTED_DEFAULT_PORT
 DLTS_HOSTFILE = "/job/hostfile"
 EXPORT_ENVS = ["NCCL", "PYTHONPATH"]
+DEEPSPEED_ENVIRONMENT_NAME = ".deepspeed_env"
+DEEPSPEED_ENVIRONMENT_PATHS = [os.path.expanduser("~"), '.']
 
 
 def parse_args(args=None):
@@ -317,6 +319,13 @@ def main(args=None):
         if any(map(lambda name: name in var, EXPORT_ENVS)):
             exports += "export {}={}; ".format(var, env[var])
 
+    for environ_path in DEEPSPEED_ENVIRONMENT_PATHS:
+        environ_file = os.path.join(environ_path, DEEPSPEED_ENVIRONMENT_NAME)
+        if os.path.isfile(environ_file):
+            with open(environ_file, 'r') as fd:
+                for var in fd.readlines():
+                    exports += "export {}; ".format(var.strip())
+
     deepspeed_launch = [
         exports,
         "cd {};".format(curr_path),
...
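For readers following the diff above, here is a self-contained sketch of how the `exports` prefix ends up being assembled once both the existing NCCL/PYTHONPATH filter and the new `.deepspeed_env` lookup run. The constants mirror the launcher, but `build_exports` is a name introduced here for illustration, not a real launcher function.

```python
import os

EXPORT_ENVS = ["NCCL", "PYTHONPATH"]
DEEPSPEED_ENVIRONMENT_NAME = ".deepspeed_env"
DEEPSPEED_ENVIRONMENT_PATHS = [os.path.expanduser("~"), '.']

def build_exports(env=None):
    env = os.environ if env is None else env
    exports = ""
    # Propagate NCCL/PYTHONPATH related variables from the current environment.
    for var in env.keys():
        if any(map(lambda name: name in var, EXPORT_ENVS)):
            exports += "export {}={}; ".format(var, env[var])
    # Append VAR=VAL entries from a .deepspeed_env file found in ~ or the
    # current working directory, if such a file exists.
    for environ_path in DEEPSPEED_ENVIRONMENT_PATHS:
        environ_file = os.path.join(environ_path, DEEPSPEED_ENVIRONMENT_NAME)
        if os.path.isfile(environ_file):
            with open(environ_file, 'r') as fd:
                for var in fd.readlines():
                    exports += "export {}; ".format(var.strip())
    return exports

if __name__ == "__main__":
    # Prints something like:
    #   export PYTHONPATH=...; export NCCL_IB_DISABLE=1; export NCCL_SOCKET_IFNAME=eth0;
    print(build_exports())
```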