Commit 4fb37d45 authored by Max Ryabinin's avatar Max Ryabinin
Browse files

Extract get_tile_inds to a separate function

parent ac5550a0
...@@ -232,6 +232,19 @@ def supports_igemmlt(device: torch.device) -> bool: ...@@ -232,6 +232,19 @@ def supports_igemmlt(device: torch.device) -> bool:
return True return True
def _get_tile_size(format):
assert format in (
"col_turing",
"col_ampere",
), f"please find this assert and manually enter tile size for {format}"
return (8, 32) if format == "col_turing" else (32, 32)
def get_tile_inds(format, device):
transform = lambda x: F.transform(x.to(device), from_order="row", to_order=format)[0].to(x.device)
with torch.no_grad():
return get_inverse_transform_indices(transform, _get_tile_size(format)).to(device)
@dataclass @dataclass
class MatmulLtState: class MatmulLtState:
_tile_indices: Optional[torch.Tensor] = None _tile_indices: Optional[torch.Tensor] = None
...@@ -267,20 +280,10 @@ class MatmulLtState: ...@@ -267,20 +280,10 @@ class MatmulLtState:
self.SBt = None self.SBt = None
self.CBt = None self.CBt = None
def get_tile_size(self):
assert self.formatB in (
"col_turing",
"col_ampere",
), f"please find this assert and manually enter tile size for {self.formatB}"
return (8, 32) if self.formatB == "col_turing" else (32, 32)
@property @property
def tile_indices(self): def tile_indices(self):
if self._tile_indices is None: if self._tile_indices is None:
device = self.CxB.device self._tile_indices = get_tile_inds(self.formatB, self.CxB.device)
transform = lambda x: F.transform(x.to(device), from_order="row", to_order=self.formatB)[0].to(x.device)
with torch.no_grad():
self._tile_indices = get_inverse_transform_indices(transform, self.get_tile_size()).to(device)
return self._tile_indices return self._tile_indices
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment