Unverified Commit d119fc86 authored by Matthew Bonanni's avatar Matthew Bonanni Committed by GitHub
Browse files

[CI][Bugfix] Fix failing Blackwell test (#24993)


Signed-off-by: default avatarMatthew Bonanni <mbonanni001@gmail.com>
Co-authored-by: default avatarRobert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com>
parent dbebb7f8
...@@ -506,12 +506,9 @@ class SharedResizableBuffer: ...@@ -506,12 +506,9 @@ class SharedResizableBuffer:
def get(self, shape: tuple[int, ...], device: torch.device, def get(self, shape: tuple[int, ...], device: torch.device,
dtype: torch.dtype): dtype: torch.dtype):
shape_numel = prod(shape) shape_numel = prod(shape)
if self.buffer is None or self.buffer.numel() < shape_numel: if (self.buffer is None or self.buffer.numel() < shape_numel
or self.buffer.device != device or self.buffer.dtype != dtype):
self.buffer = torch.empty(shape_numel, device=device, dtype=dtype) self.buffer = torch.empty(shape_numel, device=device, dtype=dtype)
assert self.buffer.device == device, \
f"Buffer device mismatch: {self.buffer.device} != {device}"
assert self.buffer.dtype == dtype, \
f"Buffer dtype mismatch: {self.buffer.dtype} != {dtype}"
return self.buffer[:shape_numel].view(*shape) return self.buffer[:shape_numel].view(*shape)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment