Commit 17035d6c authored by Jared Casper

Replace deprecated torch.norm with torch.linalg.norm.

parent e7c7a78f
@@ -112,7 +112,7 @@ def clip_grad_norm(parameters, max_norm, norm_type=2):
         total_norm = 0
         for p in parameters:
             if p.model_parallel or (get_model_parallel_rank() == 0):
-                param_norm = p.grad.data.norm(norm_type)
+                param_norm = torch.linalg.norm(p.grad.data.flatten(), norm_type)
                 total_norm += param_norm.item() ** norm_type
         # Sum across all model parallel GPUs.
         total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
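A minimal sketch (not part of the commit; the tensor shape is illustrative) of why the explicit .flatten() preserves the old behavior: torch.norm with a scalar p computes the norm of the flattened input, while torch.linalg.norm with a scalar ord only accepts 1-D (vector) or 2-D (matrix) inputs.

import torch

grad = torch.randn(4, 8, 16)  # illustrative gradient shape, not from the commit
norm_type = 2

# Deprecated call: implicitly flattens the tensor for a scalar p.
old = torch.norm(grad, norm_type)
# Replacement: flatten explicitly, since torch.linalg.norm with a scalar ord
# rejects inputs with more than two dimensions.
new = torch.linalg.norm(grad.flatten(), norm_type)
assert torch.allclose(old, new)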
@@ -65,7 +65,7 @@ def print_params_min_max_norm(optimizer, iteration):
             index += 1
             min_ = param.data.min()
             max_ = param.data.max()
-            norm = param.data.norm()
+            norm = torch.linalg.norm(param.data)
             string += '{:7d}, {:4d}, {:4d}, {:2d}, '.format(
                 iteration, rank, index, int(param.model_parallel))
             string += '{:.6E}, {:.6E}, {:.6E}\n'.format(min_, max_, norm)
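In the second hunk no ord is passed: with default arguments torch.linalg.norm flattens its input and returns the 2-norm, matching the Frobenius norm that the deprecated param.data.norm() computed, so no explicit flatten is needed there. A quick check under the same illustrative-shape assumption:

import torch

param = torch.randn(8, 16)  # illustrative parameter shape
# Both defaults reduce to the 2-norm over all elements.
assert torch.allclose(torch.norm(param), torch.linalg.norm(param))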