Unverified Commit 7a3199b1 authored by anj-s, Committed by GitHub


[refactor] Use logging in place of print statements, remove unused functions and other minor refactoring changes. (#461)

* fix pipe logging and other cleanups

* more log/debug changes
parent 428110b8
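
The change set replaces ad-hoc print statements with the standard logging module and gates verbose output behind a new --debug flag. The following is a minimal, standalone sketch of that pattern; only the --debug argument and the basicConfig call are taken from the diff below, everything else is illustrative:

    import argparse
    import logging

    parser = argparse.ArgumentParser()
    parser.add_argument("--debug", action="store_true", default=False, help="Display additional debug information")

    if __name__ == "__main__":
        args = parser.parse_args()
        # Default to INFO; passing --debug lowers the threshold so logging.debug() calls become visible.
        logging.basicConfig(level=logging.INFO if not args.debug else logging.DEBUG)
        logging.info("always shown: run configuration, throughput")
        logging.debug("shown only with --debug: per-batch stats, peak memory, epoch banners")

With this setup, messages that used to be unconditional prints map either to logging.info (run configuration, throughput) or to logging.debug (per-batch loss/ppl, peak allocated bytes, epoch banners), matching the levels chosen in the diff.
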
@@ -537,7 +537,7 @@ if __name__ == "__main__":
     args = parser.parse_args()
     # bench_multi_process(args, all_at_once=True)
     if args.no_mpi or "OMPI_COMM_WORLD_RANK" not in os.environ:
-        print(f"Can't run benchmark")
+        print("Can't run benchmark")
         sys.exit(1)
     else:
...
@@ -7,7 +7,6 @@ import gc
 import logging
 import math
 import operator
-import pprint
 import time
 from datasets.wikitext2_data import get_real_dataloaders as get_real_wikitext2_dataloaders
@@ -97,66 +96,6 @@ def get_tensors_by_size_bucket():
     return size_buckets


-def dump_size_buckets(size_buckets, prefix=""):
-    total = 0
-    for key, value in size_buckets.items():
-        this = reduce(operator.mul, key) * value
-        total += this
-        print(prefix + f"{key} : {value}, {this}")
-    print(prefix + f"total = {total}")
-
-
-last_size_buckets = None
-once = True
-
-
-def safe_rank():
-    try:
-        return torch.distributed.get_rank()
-    except AssertionError:
-        return 0
-
-
-def check_size_buckets():
-    global last_size_buckets
-    global once
-    size_buckets = get_tensors_by_size_bucket()
-    if last_size_buckets is not None:
-        if size_buckets != last_size_buckets:
-            print(f"difference is oustanding tensors: {safe-rank()}")
-            dump_size_buckets(last_size_buckets, "old: ")
-            dump_size_buckets(size_buckets, "new: ")
-        if once:
-            print(f"dumping buckets for: {safe_rank()}")
-            dump_size_buckets(last_size_buckets, "old: ")
-            dump_size_buckets(size_buckets, "new: ")
-            once = False
-    else:
-        print(f"size buckets none on {safe_rank()}")
-    last_size_buckets = size_buckets
-
-
-def dump_cuda_tensors():
-    print(f"dumping cuda tensors...")
-    for obj in gc.get_objects():
-        if not isinstance(obj, torch.Tensor):
-            continue
-        if obj.device.type == "cuda":
-            size_buckets[(*obj.size(),) + (obj.element_size(),)] += 1
-
-    print(f"outstanding cuda tensors:")
-    total = 0
-    for key, value in size_buckets.items():
-        this = reduce(operator.mul, key) * value
-        total += this
-        print(f"{key} : {value}, {this}")
-    print(f"total size = {total}")
-
-    pprint.pprint(torch.cuda.memory_stats())


 def log_number_of_parameters(model):
     num_params = reduce(operator.add, (reduce(operator.mul, x.size()) for x in model.parameters()))
@@ -165,15 +104,15 @@ def log_number_of_parameters(model):
         if torch.cuda.is_available():
             total = total.cuda()
         torch.distributed.all_reduce(total, group=model.group)
-        logging.info(
+        logging.debug(
             f"training model, #params = {num_params}, group: {model.group.rank()}, grank:"
             f" {torch.distributed.get_rank()}, sizes {model.group.size()}"
         )
         torch.distributed.barrier()
         if model.group.rank() == 0:
-            logging.info(f"total #prams = {total.item()}")
+            logging.debug(f"total #prams = {total.item()}")
     else:
-        logging.info(f"training model, #params = {num_params}")
+        logging.debug(f"training model, #params = {num_params}")


 def get_device(model, index):
@@ -292,7 +231,7 @@ def train(model_config, model, benchmark_config, model_specs, args):
             cur_loss = total_loss / log_interval
             elapsed = time.time() - start_time
             if not args.multiprocess or dist.get_rank() == dist.get_world_size() - 1:
-                print(
+                logging.debug(
                     "| batch {:5d} | wps {:5.2f} | loss {:5.2f} | ppl {:8.2f}".format(
                         i, total_tokens_per_log_interval / elapsed, cur_loss, math.exp(cur_loss)
                     )
@@ -341,7 +280,9 @@ def get_number_of_words(data):


 def verify_peak_memory(rank, golden_config, std_dev):
-    print("Peak allocated bytes on cuda:0: {:1d}".format(torch.cuda.memory_stats(rank)["allocated_bytes.all.peak"]))
+    logging.debug(
+        "Peak allocated bytes on cuda:0: {:1d}".format(torch.cuda.memory_stats(rank)["allocated_bytes.all.peak"])
+    )
     current_device_usage = torch.cuda.memory_stats(rank)["allocated_bytes.all.peak"]
     golden_ref = golden_config["peak_mem_usage"][rank]
     if not current_device_usage < golden_ref * std_dev:
@@ -358,7 +299,7 @@ def verify_lm_run(wps, golden_config, args):
     if not args.multiprocess or dist.get_rank() == dist.get_world_size() - 1:
         # Assert that words per second is within 3 standard deviations of the average
         # of five golden runs
-        print("Throughput(wps) is {:.2f}.".format(wps))
+        logging.info("Throughput(wps) is {:.2f}.".format(wps))
         if not wps > (golden_config["avg_wps"] - (3 * golden_config["std_dev_wps"])):
             raise RuntimeError(
                 "Throughput(wps):{:.2f} is below the golden threshold of an "
@@ -379,17 +320,17 @@ def benchmark_language_model(model_config, model, benchmark_config, model_specs,
     epoch = benchmark_config["epochs"]
     start_time = time.time()
     if dist.get_rank() == dist.get_world_size() - 1:
-        print("-" * 110)
-        print("| start of epoch {:1d}".format(epoch))
-        print("-" * 110)
+        logging.debug("-" * 110)
+        logging.debug("| start of epoch {:1d}".format(epoch))
+        logging.debug("-" * 110)
     wps, loss = train(model_config, model, benchmark_config, model_specs, args)
     elapsed_time = time.time() - start_time
     if dist.get_rank() == dist.get_world_size() - 1:
-        print("-" * 110)
-        print("| end of epoch {:1d} | time: {:5.2f}s | train loss {:5.2f} ".format(epoch, elapsed_time, loss))
-        print("-" * 110)
-        print("Throughput(wps) is {:.2f}.".format(wps))
-        print(
+        logging.debug("-" * 110)
+        logging.debug("| end of epoch {:1d} | time: {:5.2f}s | train loss {:5.2f} ".format(epoch, elapsed_time, loss))
+        logging.debug("-" * 110)
+        logging.debug("Throughput(wps) is {:.2f}.".format(wps))
+        logging.debug(
             "Peak allocated bytes on cuda:{}: {:1d}".format(
                 dist.get_rank(), torch.cuda.memory_stats(dist.get_rank())["allocated_bytes.all.peak"]
             )
@@ -402,17 +343,6 @@ def benchmark_language_model(model_config, model, benchmark_config, model_specs,
         raise RuntimeError("Unrecognized args.model_name " % args.model_name)


-def generate_balance_weighted(num_devices, num_layers, fraction=0.5):
-    balance = []
-    layers_assigned = 0
-    average_count = num_layers / num_devices
-    last_layers = int(average_count * fraction)
-    balance = generate_balance(num_devices - 1, num_layers - last_layers)
-    balance.append(last_layers)
-    return balance
-
-
 def generate_balance(num_devices, num_layers):
     balance = []
     layers_assigned = 0
@@ -539,9 +469,6 @@ def run_mp_worker(args, available_workers):
     )
     if torch.cuda.is_available():
         pipe_model = pipe_model.cuda()
-    if args.all_at_once and pipe_model.pipeline:
-        print(f"running all at once")
-        pipe_model.pipeline.all_at_once = True

     if args.dry_run:
         train(model_config, pipe_model, benchmark_config, model_specs, args)
@@ -610,16 +537,16 @@ parser.add_argument(
     default="lm",
     help="Language Model(LM) used to benchmark nn.pipe.",
 )
+parser.add_argument("--debug", action="store_true", default=False, help="Display additional debug information")

 if __name__ == "__main__":
     args = parser.parse_args()
-    # TODO(anj-s): Remove print statements and introduce logging levels.
+    logging.basicConfig(level=logging.INFO if not args.debug else logging.DEBUG)

     if not args.multiprocess:
-        print(f"Running single process benchmark with args: {args}")
+        logging.info(f"Running single process benchmark with args: {args}")
         benchmark_single_process(args)
     else:
         world_size = max(torch.cuda.device_count(), 1)
-        print(f"Running multiprocess benchmark with args: {args}")
+        logging.info(f"Running multiprocess benchmark with args: {args}")
         mp.spawn(benchmark_multiprocess, args=(world_size, args), nprocs=world_size, join=True)
@@ -433,7 +433,7 @@ def run_test_pipe(rank, world_size, filename, filename_rpc, skip_dist_init=False
         worker_map = {i: f"Test{i}" for i in range(torch.distributed.get_world_size())}

         if pipe_world_size == 2:
-            print(f"actually doing pipe stuff now")
+            print("actually doing pipe stuff now")
             assert torch.equal(saved_weight_0, model[0].weight.data)
             assert torch.equal(saved_weight_2, model[2].weight.data)
             pipe_model = MultiProcessPipe(
@@ -496,7 +496,7 @@ def run_test_pipe(rank, world_size, filename, filename_rpc, skip_dist_init=False
             dump_opt_params(optimizer)
             optimizer.step()
-            print(f"calling check_weights on master")
+            print("calling check_weights on master")
             check_weights(model, reference, "pipe", index=2)
             print(f"waiting for barrier on master, pid={os.getpid()}")
         else:
@@ -511,11 +511,11 @@ def run_test_pipe(rank, world_size, filename, filename_rpc, skip_dist_init=False
             if failed:
                 raise RuntimeError("failed somehow")
             dump_opt_params(optimizer)
-            print(f"calling step on slave")
+            print("calling step on slave")
             optimizer.step()
-            print(f"calling check_weights on slave")
+            print("calling check_weights on slave")
             check_weights(model, reference, "pipe", index=0)
-            print(f"waiting for barrier on slave")
+            print("waiting for barrier on slave")
             pipe_model.zero_grad()
         torch.distributed.barrier()
...
@@ -58,7 +58,7 @@ def pytest_report_header() -> str:


 def pytest_runtest_setup(item: Any) -> None:
-    print(f"setup mpi function called")
+    print("setup mpi function called")


 def pytest_runtest_teardown(item: Any) -> None:
...
@@ -58,7 +58,7 @@ def pytest_report_header() -> str:


 def pytest_runtest_setup(item: Any) -> None:
-    print(f"setup mpi function called")
+    print("setup mpi function called")


 def pytest_runtest_teardown(item: Any) -> None:
...