Unverified Commit 6f516b8d authored by Casper, committed by GitHub
Browse files

Fixed multi-GPU quantization (#196)

parent 74d0fe44
...@@ -69,8 +69,15 @@ class AwqQuantizer: ...@@ -69,8 +69,15 @@ class AwqQuantizer:
def quantize(self): def quantize(self):
for i in tqdm(range(len(self.modules)), desc="AWQ"): for i in tqdm(range(len(self.modules)), desc="AWQ"):
# Move module and inputs to correct device
common_device = next(self.modules[i].parameters()).device
if common_device is None or str(common_device) == "cpu":
self.modules[i] = self.modules[i].cuda()
common_device = next(self.modules[i].parameters()).device
self.inps = self.inps.to(common_device)
# [STEP 1]: Get layer, extract linear modules, extract input features # [STEP 1]: Get layer, extract linear modules, extract input features
self.modules[i] = self.modules[i].cuda()
named_linears = get_named_linears(self.modules[i]) named_linears = get_named_linears(self.modules[i])
input_feat = self._get_input_feat(self.modules[i], named_linears) input_feat = self._get_input_feat(self.modules[i], named_linears)
clear_memory() clear_memory()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment