"vscode:/vscode.git/clone" did not exist on "6fac1369d0140ebeafdaeeef3f558f21fa3d5108"
Unverified Commit 8353c594 authored by eltonzheng's avatar eltonzheng Committed by GitHub
Browse files

reduce memcpy between host and device (#248)

parent bbd8cd7d
...@@ -362,7 +362,8 @@ class FP16_DeepSpeedZeroOptimizer(object): ...@@ -362,7 +362,8 @@ class FP16_DeepSpeedZeroOptimizer(object):
for i, group in enumerate(self.fp16_groups): for i, group in enumerate(self.fp16_groups):
single_grad_partition = torch.zeros( single_grad_partition = torch.zeros(
int(self.partition_size[i]), int(self.partition_size[i]),
dtype=self.single_partition_of_fp32_groups[i].dtype).cuda() dtype=self.single_partition_of_fp32_groups[i].dtype,
device=torch.cuda.current_device())
self.single_partition_of_fp32_groups[i].grad = single_grad_partition self.single_partition_of_fp32_groups[i].grad = single_grad_partition
self.optimizer.step() self.optimizer.step()
...@@ -674,7 +675,8 @@ class FP16_DeepSpeedZeroOptimizer(object): ...@@ -674,7 +675,8 @@ class FP16_DeepSpeedZeroOptimizer(object):
see_memory_usage(f"before copying {total_size} gradients into partition") see_memory_usage(f"before copying {total_size} gradients into partition")
self.grads_in_partition = torch.empty(int(total_size), self.grads_in_partition = torch.empty(int(total_size),
dtype=torch.half).cuda() dtype=torch.half,
device=torch.cuda.current_device())
see_memory_usage(f"after copying {total_size} gradients into partition") see_memory_usage(f"after copying {total_size} gradients into partition")
#The allreduce buffer will be rewritted. Copy the gradients in partition to a new buffer #The allreduce buffer will be rewritted. Copy the gradients in partition to a new buffer
...@@ -1282,12 +1284,16 @@ class FP16_DeepSpeedZeroOptimizer(object): ...@@ -1282,12 +1284,16 @@ class FP16_DeepSpeedZeroOptimizer(object):
""" """
if self.contiguous_gradients: if self.contiguous_gradients:
self.ipg_buffer = [] self.ipg_buffer = []
buf_0 = torch.empty(self.reduce_bucket_size, dtype=torch.half).cuda() buf_0 = torch.empty(self.reduce_bucket_size,
dtype=torch.half,
device=torch.cuda.current_device())
self.ipg_buffer.append(buf_0) self.ipg_buffer.append(buf_0)
# Use double buffers to avoid data access conflict when overlap_comm is enabled. # Use double buffers to avoid data access conflict when overlap_comm is enabled.
if self.overlap_comm: if self.overlap_comm:
buf_1 = torch.empty(self.reduce_bucket_size, dtype=torch.half).cuda() buf_1 = torch.empty(self.reduce_bucket_size,
dtype=torch.half,
device=torch.cuda.current_device())
self.ipg_buffer.append(buf_1) self.ipg_buffer.append(buf_1)
self.ipg_index = 0 self.ipg_index = 0
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment