Commit b491c2d6 authored by Casper

Replace print with logging, Remove commented-out code

parent f741f406
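Note on the change as a whole: routing diagnostics through `logging` means the new `logging.debug(...)` calls are silent under Python's default `WARNING` level. A minimal sketch of how a caller could opt in (the format string here is illustrative, not part of this commit):

```python
import logging

# Debug-level records are dropped by default; enable them explicitly.
logging.basicConfig(
    level=logging.DEBUG,
    format="%(asctime)s %(levelname)s %(name)s: %(message)s",
)
```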
@@ -34,8 +34,6 @@ class QuantLlamaRotaryEmbedding(nn.Module):
         sin = freqs.sin()
         cache = torch.cat((cos, sin), dim=-1)
-        # self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
-        # self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)
         self.register_buffer("cos_sin_cache", cache.half(), persistent=False)
     def forward(
@@ -46,7 +44,6 @@ class QuantLlamaRotaryEmbedding(nn.Module):
     ):
         # Apply rotary embedding to the query and key before passing them
         # to the attention op.
-        # print(positions.shape, query.shape, key.shape, self.cos_sin_cache.shape)
         query = query.contiguous()
         key = key.contiguous()
         awq_inference_engine.rotary_embedding_neox(
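For context, a hedged pure-PyTorch reference of what a NeoX-style rotary kernel consuming the packed `cos_sin_cache` above is expected to compute. The argument layout of `awq_inference_engine.rotary_embedding_neox` itself is not shown in this hunk, so the shapes here are assumptions, not a drop-in replacement for the kernel call:

```python
import torch

def rotary_neox_ref(positions, x, cos_sin_cache):
    # x: (num_tokens, num_heads, head_dim)
    # cos_sin_cache: (max_pos, head_dim), cos in the first half and sin in the
    # second, matching cache = torch.cat((cos, sin), dim=-1) built above.
    cos, sin = cos_sin_cache[positions].unsqueeze(1).chunk(2, dim=-1)
    x1, x2 = x.chunk(2, dim=-1)  # NeoX convention: rotate halves, not interleaved pairs
    return torch.cat((x1 * cos - x2 * sin, x2 * cos + x1 * sin), dim=-1)
```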
@@ -146,7 +143,7 @@ def make_quant_attn(model, dev):
         qweights = torch.cat([q_proj.qweight, k_proj.qweight, v_proj.qweight], dim=1)
         qzeros = torch.cat([q_proj.qzeros, k_proj.qzeros, v_proj.qzeros], dim=1)
         scales = torch.cat([q_proj.scales, k_proj.scales, v_proj.scales], dim=1)
-        # g_idx = torch.cat([q_proj.g_idx, k_proj.g_idx, v_proj.g_idx], dim=0)
         g_idx = None
         bias = torch.cat([q_proj.bias, k_proj.bias, v_proj.bias], dim=0) if q_proj.bias is not None else None
@@ -156,8 +153,6 @@ def make_quant_attn(model, dev):
         qkv_layer.scales = scales
         qkv_layer.bias = bias
-        # We're dropping the rotary embedding layer m.rotary_emb here. We don't need it in the triton branch.
         attn = QuantLlamaAttention(m.hidden_size, m.num_heads, qkv_layer, m.o_proj, dev)
         if '.' in name:
@@ -169,6 +164,4 @@ def make_quant_attn(model, dev):
             parent = model
             child_name = name
-        #print(f"Replacing {name} with quant_attn; parent: {parent_name}, child's name: {child_name}")
         setattr(parent, child_name, attn)
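The pattern in this file: the three quantized projections are fused by concatenating their packed tensors along the output-channel axis (`dim=1` for `qweight`/`qzeros`/`scales`, `dim=0` for the unpacked bias), and the original attention module is then swapped in place by dotted name. A hedged sketch of that replace-in-place step, assuming standard `nn.Module` naming:

```python
import torch.nn as nn

def replace_module(model: nn.Module, name: str, new_module: nn.Module) -> None:
    """Rebind the dotted-name child `name` of `model` to `new_module`."""
    if '.' in name:
        parent_name, child_name = name.rsplit('.', 1)
        parent = model.get_submodule(parent_name)  # walk to the parent module
    else:
        parent, child_name = model, name
    setattr(parent, child_name, new_module)
```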
@@ -71,7 +71,6 @@ class QuantLlamaMLP(nn.Module):
 def make_fused_mlp(m, parent_name=''):
     if not hasattr(make_fused_mlp, "called"):
-        # print("[Warning] Calling a fake MLP fusion. But still faster than Huggingface Implimentation.")
         make_fused_mlp.called = True
     """
     Replace all LlamaMLP modules with QuantLlamaMLP modules, which fuses many of the operations.
......
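The `hasattr(make_fused_mlp, "called")` check above is a one-shot guard: state is stored on the function object itself, so the (now removed) warning would fire only on the first call. A standalone illustration of the pattern, with a placeholder message:

```python
import logging

def warn_once():
    # Attributes set on the function object persist across calls,
    # so this branch executes exactly once per process.
    if not hasattr(warn_once, "called"):
        logging.warning("placeholder: emitted only on the first call")
        warn_once.called = True
```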
@@ -38,6 +38,4 @@ def make_quant_norm(model):
             parent = model
             child_name = name
-        #print(f"Replacing {name} with quant_attn; parent: {parent_name}, child's name: {child_name}")
         setattr(parent, child_name, norm)
 import gc
 import torch
 import torch.nn as nn
+import logging
 from transformers.models.bloom.modeling_bloom import BloomBlock, BloomGelu
 from transformers.models.opt.modeling_opt import OPTDecoderLayer
@@ -154,9 +155,8 @@ def auto_scale_block(awq_model,
                 best_scales = scales
             block.load_state_dict(org_sd)
         if best_ratio == -1:
-            print(history)
+            logging.debug(history)
             raise Exception
-        # print(best_ratio)
         best_scales = best_scales.view(-1)
         assert torch.isnan(best_scales).sum() == 0, best_scales
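`best_ratio == -1` is a sentinel meaning no candidate ratio ever improved on the initial `float("inf")` error, which can only happen when every candidate loss is NaN (NaN comparisons are always false); hence the history is logged before aborting, and the NaN assertion follows. A hedged sketch of the kind of grid search this guards; `n_grid` and `evaluate_loss` are assumed/hypothetical names, since the hunk does not show the loop itself:

```python
import logging
import torch

def search_best_scales(x_max: torch.Tensor, evaluate_loss, n_grid: int = 20):
    # Grid-search an exponent `ratio`; `evaluate_loss` is a hypothetical
    # callable returning the quantized-vs-float loss for candidate scales.
    best_error, best_ratio, best_scales, history = float("inf"), -1, None, []
    for i in range(n_grid):
        ratio = i / n_grid
        scales = x_max.pow(ratio).clamp(min=1e-4)  # per-channel candidate scales
        loss = evaluate_loss(scales)
        history.append(loss)
        if loss < best_error:  # comparison is False for NaN, so NaN never wins
            best_error, best_ratio, best_scales = loss, ratio, scales
    if best_ratio == -1:
        logging.debug(history)
        raise RuntimeError("no valid candidate scales (all losses NaN?)")
    return best_scales.view(-1)
```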
......
 import torch
+import logging
 from datasets import load_dataset
 def get_calib_dataset(data="pileval", tokenizer=None, n_samples=512, block_size=512):
@@ -25,5 +26,5 @@ def get_calib_dataset(data="pileval", tokenizer=None, n_samples=512, block_size=
     # now concatenate all samples and split according to block size
     cat_samples = torch.cat(samples, dim=1)
     n_split = cat_samples.shape[1] // block_size
-    print(f" * Split into {n_split} blocks")
+    logging.debug(f" * Split into {n_split} blocks")
     return [cat_samples[:, i*block_size:(i+1)*block_size] for i in range(n_split)]
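A quick toy run of the concatenate-then-split logic above, showing that the trailing remainder shorter than `block_size` is silently dropped:

```python
import torch

samples = [torch.zeros(1, 300, dtype=torch.long), torch.zeros(1, 900, dtype=torch.long)]
cat_samples = torch.cat(samples, dim=1)        # shape (1, 1200)
block_size = 512
n_split = cat_samples.shape[1] // block_size   # 2 blocks; 176 leftover tokens dropped
blocks = [cat_samples[:, i * block_size:(i + 1) * block_size] for i in range(n_split)]
assert all(b.shape == (1, block_size) for b in blocks)
```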
@@ -2,7 +2,7 @@ import transformers
 import torch
 from lm_eval.base import BaseLM
 import fnmatch
+import logging
 class LMEvalAdaptor(BaseLM):
@@ -52,7 +52,7 @@ class LMEvalAdaptor(BaseLM):
         elif 'falcon' in self.model_name:
             return 2048
         else:
-            print(self.model.config)
+            logging.debug(self.model.config)
             raise NotImplementedError
     @property
......
 import os
 import torch
 import gc
+import logging
 def auto_parallel(args):
@@ -23,5 +24,5 @@ def auto_parallel(args):
         cuda_visible_devices = list(range(8))
     os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(
         [str(dev) for dev in cuda_visible_devices[:n_gpu]])
-    print("CUDA_VISIBLE_DEVICES: ", os.environ["CUDA_VISIBLE_DEVICES"])
+    logging.debug(f"CUDA_VISIBLE_DEVICES: {os.environ['CUDA_VISIBLE_DEVICES']}")
     return cuda_visible_devices
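Two caveats on this last hunk. First, `logging.debug` does not behave like `print` with multiple positional arguments: extra arguments are treated as %-format parameters, so the two-argument form would produce a logging formatting error rather than the intended message; the added line above is therefore written as an f-string. The lazy %-style form is the other idiomatic option, deferring string construction until a record is actually emitted:

```python
import logging
import os

# Lazy %-formatting: the substitution only happens if DEBUG is enabled.
logging.debug("CUDA_VISIBLE_DEVICES: %s", os.environ.get("CUDA_VISIBLE_DEVICES", ""))
```

Second, `CUDA_VISIBLE_DEVICES` only takes effect if it is set before the process first initializes CUDA, so `auto_parallel` must run before any tensor touches a GPU.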