"git@developer.sourcefind.cn:OpenDAS/pytorch3d.git" did not exist on "349a499f339fd0e758ee8bd54c09ee073cfbaf3b"
Commit b491c2d6 authored by Casper

Replace print with logging, remove commented-out code

parent f741f406
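The commit swaps print calls for logging.debug throughout the AWQ utilities. Since debug-level records are suppressed under the default WARNING threshold, a caller has to opt in to see the new output. A minimal sketch (the format string is an arbitrary choice, not part of the commit):

```python
import logging

# Enable DEBUG output so the messages added in this commit become visible.
logging.basicConfig(
    level=logging.DEBUG,
    format="%(asctime)s %(levelname)s %(name)s: %(message)s",
)

logging.debug("debug output from the AWQ utilities is now visible")
```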
@@ -34,8 +34,6 @@ class QuantLlamaRotaryEmbedding(nn.Module):
         sin = freqs.sin()
         cache = torch.cat((cos, sin), dim=-1)
-        # self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
-        # self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)
         self.register_buffer("cos_sin_cache", cache.half(), persistent=False)

     def forward(
@@ -46,7 +44,6 @@ class QuantLlamaRotaryEmbedding(nn.Module):
     ):
         # Apply rotary embedding to the query and key before passing them
         # to the attention op.
-        # print(positions.shape, query.shape, key.shape, self.cos_sin_cache.shape)
         query = query.contiguous()
         key = key.contiguous()
         awq_inference_engine.rotary_embedding_neox(
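The hunks above drop the unused cos_cached/sin_cached buffers and the debug print, keeping only the fused cos_sin_cache consumed by the CUDA kernel. A hedged sketch of how such a fused table can be precomputed (class name, dimensions, and dtype choices are illustrative, not the exact module):

```python
import torch
import torch.nn as nn

class RotaryCacheSketch(nn.Module):
    """Illustrative only: precompute a fused [max_position, rotary_dim] cos/sin table."""

    def __init__(self, rotary_dim: int, max_position: int = 2048, base: float = 10000.0):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, rotary_dim, 2).float() / rotary_dim))
        t = torch.arange(max_position).float()
        freqs = torch.einsum("i,j->ij", t, inv_freq)  # [max_position, rotary_dim // 2]
        cos = freqs.cos()
        sin = freqs.sin()
        # One fused buffer instead of separate cos/sin caches, mirroring cos_sin_cache above.
        cache = torch.cat((cos, sin), dim=-1)
        self.register_buffer("cos_sin_cache", cache.half(), persistent=False)

    def lookup(self, positions: torch.Tensor) -> torch.Tensor:
        # The fused kernel indexes this table by token position; here we just gather rows.
        return self.cos_sin_cache[positions]
```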
@@ -146,7 +143,7 @@ def make_quant_attn(model, dev):
         qweights = torch.cat([q_proj.qweight, k_proj.qweight, v_proj.qweight], dim=1)
         qzeros = torch.cat([q_proj.qzeros, k_proj.qzeros, v_proj.qzeros], dim=1)
         scales = torch.cat([q_proj.scales, k_proj.scales, v_proj.scales], dim=1)
-        # g_idx = torch.cat([q_proj.g_idx, k_proj.g_idx, v_proj.g_idx], dim=0)
         g_idx = None
         bias = torch.cat([q_proj.bias, k_proj.bias, v_proj.bias], dim=0) if q_proj.bias is not None else None
@@ -156,8 +153,6 @@ def make_quant_attn(model, dev):
         qkv_layer.scales = scales
         qkv_layer.bias = bias

-        # We're dropping the rotary embedding layer m.rotary_emb here. We don't need it in the triton branch.
         attn = QuantLlamaAttention(m.hidden_size, m.num_heads, qkv_layer, m.o_proj, dev)

         if '.' in name:
@@ -169,6 +164,4 @@ def make_quant_attn(model, dev):
             parent = model
             child_name = name

-        #print(f"Replacing {name} with quant_attn; parent: {parent_name}, child's name: {child_name}")
         setattr(parent, child_name, attn)
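These hunks concatenate the q/k/v projections' packed weights, zeros, and scales into one fused QKV layer before swapping the attention module in place; the packed qweight tensors keep the output dimension last, which is why the concatenation runs along dim=1. A hedged sketch of the same fusion idea on plain nn.Linear layers (the helper name is a placeholder; WQLinear/QuantLlamaAttention internals are not reproduced):

```python
import torch
import torch.nn as nn

def fuse_qkv_linear(q_proj: nn.Linear, k_proj: nn.Linear, v_proj: nn.Linear) -> nn.Linear:
    """Illustrative only: merge three projections so a single matmul
    produces the concatenated [q | k | v] output, analogous to qkv_layer above."""
    out_features = q_proj.out_features + k_proj.out_features + v_proj.out_features
    has_bias = q_proj.bias is not None
    fused = nn.Linear(q_proj.in_features, out_features, bias=has_bias)
    with torch.no_grad():
        # nn.Linear stores weight as [out, in], so the output axis is dim=0 here,
        # unlike the packed quantized tensors in the diff.
        fused.weight.copy_(torch.cat([q_proj.weight, k_proj.weight, v_proj.weight], dim=0))
        if has_bias:
            fused.bias.copy_(torch.cat([q_proj.bias, k_proj.bias, v_proj.bias], dim=0))
    return fused
```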
@@ -71,7 +71,6 @@ class QuantLlamaMLP(nn.Module):
 def make_fused_mlp(m, parent_name=''):
     if not hasattr(make_fused_mlp, "called"):
-        # print("[Warning] Calling a fake MLP fusion. But still faster than Huggingface Implimentation.")
         make_fused_mlp.called = True
     """
     Replace all LlamaMLP modules with QuantLlamaMLP modules, which fuses many of the operations.
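The removed print was a one-shot warning guarded by a function attribute, so it fired only on the first call. A hedged sketch of the same warn-once pattern using the logging module (function name and message are placeholders):

```python
import logging

def warn_once_demo():
    # Function attributes persist across calls, so this branch runs exactly once.
    if not hasattr(warn_once_demo, "called"):
        logging.warning("MLP fusion is partial; still faster than the unfused path.")
        warn_once_demo.called = True

warn_once_demo()  # emits the warning
warn_once_demo()  # silent on subsequent calls
```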
@@ -38,6 +38,4 @@ def make_quant_norm(model):
             parent = model
             child_name = name

-        #print(f"Replacing {name} with quant_attn; parent: {parent_name}, child's name: {child_name}")
         setattr(parent, child_name, norm)
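Both make_quant_attn and make_quant_norm replace modules by resolving the dotted module name to its parent and calling setattr. A hedged sketch of that replacement-by-name pattern (the helper name is illustrative; the repository's own lookup code may differ):

```python
import torch.nn as nn

def replace_module_by_name(model: nn.Module, name: str, new_module: nn.Module) -> None:
    """Illustrative only: swap the submodule at a dotted path such as
    'model.layers.0.self_attn' by resolving its parent first."""
    if '.' in name:
        parent_name, child_name = name.rsplit('.', 1)
        parent = model.get_submodule(parent_name)
    else:
        parent, child_name = model, name
    setattr(parent, child_name, new_module)
```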
 import gc
 import torch
 import torch.nn as nn
+import logging

 from transformers.models.bloom.modeling_bloom import BloomBlock, BloomGelu
 from transformers.models.opt.modeling_opt import OPTDecoderLayer
@@ -154,9 +155,8 @@ def auto_scale_block(awq_model,
                 best_scales = scales
             block.load_state_dict(org_sd)
         if best_ratio == -1:
-            print(history)
+            logging.debug(history)
             raise Exception
-        # print(best_ratio)

         best_scales = best_scales.view(-1)
         assert torch.isnan(best_scales).sum() == 0, best_scales
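In this hunk, history collects the loss recorded for each candidate scaling ratio, and best_ratio == -1 means no candidate ever improved on the initial error; the commit now dumps that history at debug level before raising. A hedged sketch of the surrounding grid-search shape (the loss callback, grid size, and exception type are placeholders, not the actual AWQ search):

```python
import logging

def search_best_ratio(evaluate_loss, n_grid: int = 20) -> float:
    """Illustrative only: pick the ratio in [0, 1) with the lowest loss,
    keeping a history so a failed search can be inspected via logging."""
    history = []
    best_error = float("inf")
    best_ratio = -1.0
    for i in range(n_grid):
        ratio = i / n_grid
        loss = evaluate_loss(ratio)
        history.append(loss)
        if loss < best_error:  # never true if every loss is NaN/inf
            best_error = loss
            best_ratio = ratio
    if best_ratio == -1.0:
        logging.debug(history)  # same failure dump as in the diff
        raise RuntimeError("no scaling ratio improved the loss")
    return best_ratio
```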
 import torch
+import logging
 from datasets import load_dataset

 def get_calib_dataset(data="pileval", tokenizer=None, n_samples=512, block_size=512):
@@ -25,5 +26,5 @@ def get_calib_dataset(data="pileval", tokenizer=None, n_samples=512, block_size=
     # now concatenate all samples and split according to block size
     cat_samples = torch.cat(samples, dim=1)
     n_split = cat_samples.shape[1] // block_size
-    print(f" * Split into {n_split} blocks")
+    logging.debug(f" * Split into {n_split} blocks")
     return [cat_samples[:, i*block_size:(i+1)*block_size] for i in range(n_split)]
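The calibration loader concatenates the tokenized samples along the sequence dimension and slices them into fixed-size blocks, reporting the block count at debug level. A small self-contained sketch of that concat-and-split step (sample shapes of [1, seq_len] and the vocabulary size are assumptions):

```python
import logging
import torch

def split_into_blocks(samples, block_size=512):
    """Illustrative only: mirror the concat/split step of get_calib_dataset."""
    cat_samples = torch.cat(samples, dim=1)           # [1, total_len]
    n_split = cat_samples.shape[1] // block_size      # the ragged tail is dropped
    logging.debug(f" * Split into {n_split} blocks")
    return [cat_samples[:, i * block_size:(i + 1) * block_size] for i in range(n_split)]

blocks = split_into_blocks(
    [torch.randint(0, 32000, (1, 700)), torch.randint(0, 32000, (1, 900))],
    block_size=512,
)
assert all(b.shape == (1, 512) for b in blocks)  # 1600 tokens -> 3 full blocks
```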
@@ -2,7 +2,7 @@ import transformers
 import torch
 from lm_eval.base import BaseLM
 import fnmatch
+import logging

 class LMEvalAdaptor(BaseLM):
@@ -52,7 +52,7 @@ class LMEvalAdaptor(BaseLM):
         elif 'falcon' in self.model_name:
             return 2048
         else:
-            print(self.model.config)
+            logging.debug(self.model.config)
             raise NotImplementedError

     @property
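The else branch handles models whose context length cannot be inferred from the model name; it now logs the config at debug level before raising NotImplementedError. A hedged sketch of a config-based fallback (the helper name and the attribute list are assumptions; they cover common HF config fields but are not exhaustive):

```python
import logging

def infer_max_length(config, default: int = 2048) -> int:
    """Illustrative only: probe common HF config fields for a context length."""
    for attr in ("n_positions", "max_position_embeddings", "n_ctx", "seq_length"):
        if hasattr(config, attr):
            return getattr(config, attr)
    logging.debug(config)  # same diagnostic the adaptor now emits
    return default
```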
 import os
 import torch
 import gc
+import logging

 def auto_parallel(args):
@@ -23,5 +24,5 @@ def auto_parallel(args):
         cuda_visible_devices = list(range(8))
     os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(
         [str(dev) for dev in cuda_visible_devices[:n_gpu]])
-    print("CUDA_VISIBLE_DEVICES: ", os.environ["CUDA_VISIBLE_DEVICES"])
+    logging.debug("CUDA_VISIBLE_DEVICES: ", os.environ["CUDA_VISIBLE_DEVICES"])
     return cuda_visible_devices
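One caveat worth noting: unlike print, logging.debug does not append extra positional arguments to the message; they are treated as %-style formatting arguments, so the value in the hunk above would not be rendered as it was with print. A short sketch of the conventional lazy-formatting form:

```python
import logging
import os

logging.basicConfig(level=logging.DEBUG)
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "0,1")

# print-style: the second argument is appended to the output automatically.
print("CUDA_VISIBLE_DEVICES: ", os.environ["CUDA_VISIBLE_DEVICES"])

# logging-style: use a %s placeholder so the argument is formatted lazily.
logging.debug("CUDA_VISIBLE_DEVICES: %s", os.environ["CUDA_VISIBLE_DEVICES"])
```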