Unverified Commit e2c39426 authored by Mehdi Mirzazadeh's avatar Mehdi Mirzazadeh Committed by GitHub
Browse files

Fixing memory leak in distributed pipeline (#724)

* Fixing memory leak in distributed pipeline

* fix mypy error
parent 681606f0
...@@ -273,14 +273,19 @@ class PartitionHandler: ...@@ -273,14 +273,19 @@ class PartitionHandler:
pipeline_record = pipeline_record_rref.local_value() pipeline_record = pipeline_record_rref.local_value()
self.run(pipeline_record) self.run(pipeline_record)
result: Optional[Tensor] = None
if not pipeline_record.consumers: if not pipeline_record.consumers:
result = microbatch.gather(pipeline_record.batches) gather_result = microbatch.gather(pipeline_record.batches)
assert len(result) == 1 assert len(gather_result) == 1
result = result[0] result = gather_result[0]
s0 = current_stream(result.device) s0 = current_stream(result.device)
if is_cuda(s0): if is_cuda(s0):
# TODO. Investigate why this is needed and remove it if possible. # TODO. Investigate why this is needed and remove it if possible.
as_cuda(s0).synchronize() as_cuda(s0).synchronize()
return result
return None # TODO: There seems to be a memory leak that is solved by the following line.
# Investigate why it is needed.
del pipeline_record.batches
return result
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment