Unverified commit b6dc35fe authored by WRH, committed by GitHub

[Improve] Add docstrings to pytorch submodule (#93)

* add some docstrings.

* update docstring.

fix

* ignore magic methods
parent 69b6eabe
@@ -39,7 +39,7 @@ jobs:
- name: Check docstring coverage
run: |
python -m pip install interrogate
- interrogate -v --ignore-init-method --ignore-module --ignore-private --ignore-nested-functions --ignore-nested-classes --fail-under 80 lmdeploy
+ interrogate -v --ignore-init-method --ignore-magic --ignore-module --ignore-private --ignore-nested-functions --ignore-nested-classes --fail-under 80 lmdeploy
- name: Check pylint score
run: |
python -m pip install pylint
......
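For context, interrogate's `--ignore-magic` flag excludes magic (dunder) methods of classes from the docstring-coverage check, which is why the new dunder docstring requirement disappears alongside this change. A minimal sketch of its effect on a hypothetical module:

```python
# Hypothetical module illustrating --ignore-magic: with the flag,
# __repr__ below is excluded from the coverage denominator entirely.
class Point:
    """A 2D point."""

    def __init__(self, x, y):  # already excluded via --ignore-init-method
        self.x, self.y = x, y

    def __repr__(self):  # magic method: excluded via --ignore-magic
        return f'Point({self.x}, {self.y})'

    def undocumented(self):  # still counts against --fail-under 80
        pass
```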
@@ -15,6 +15,8 @@ class LoadNoInit:
self.kaiming_normal_ = torch.nn.init.kaiming_normal_
def __enter__(self, *args, **kwargs):
"""Replace initializers with no-op."""
torch.nn.init.constant_ = lambda *args, **kwargs: None
torch.nn.init.zeros_ = lambda *args, **kwargs: None
torch.nn.init.ones_ = lambda *args, **kwargs: None
@@ -24,6 +26,8 @@ class LoadNoInit:
torch.nn.init.kaiming_normal_ = lambda *args, **kwargs: None
def __exit__(self, *args, **kwargs):
"""Recover."""
torch.nn.init.constant_ = self.constant_
torch.nn.init.zeros_ = self.zeros_
torch.nn.init.ones_ = self.ones_
......
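A sketch of how a context manager like `LoadNoInit` is meant to be used: while the context is active, `torch.nn.init.*` are no-ops, so constructing a model whose weights will immediately be overwritten by a checkpoint wastes no time on random initialization. The model class and path below are placeholders, not taken from this diff:

```python
# Illustrative usage; AutoModelForCausalLM and the path are placeholders.
from transformers import AutoModelForCausalLM

with LoadNoInit():
    # torch.nn.init.* are no-ops here, so construction skips random init
    model = AutoModelForCausalLM.from_pretrained('path/to/checkpoint')
# on exit, the saved initializers are restored for later callers
```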
@@ -26,6 +26,8 @@ except ImportError:
def input_prompt():
"""Helper function for getting input from users."""
print('\ndouble enter to end input >>> ', end='')
sentinel = '' # ends when this string is seen
return '\n'.join(iter(input, sentinel))
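The two-argument form `iter(callable, sentinel)` calls `input` repeatedly and stops as soon as it returns the sentinel, here the empty string produced by pressing enter twice. The same idiom, testable without a terminal:

```python
# Stand-in for input() so the idiom can run without a terminal.
replies = iter(['hello', 'world', '', 'never consumed'])
text = '\n'.join(iter(replies.__next__, ''))  # stops at the '' sentinel
assert text == 'hello\nworld'
```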
@@ -38,10 +40,19 @@ def init_model(
local_rank=0,
world_size=1,
):
"""Note:
If the model is converted from new version of transformers,
use_fast_tokenizer should be True.
If using depodaca/llama-xb-hf, use_fast_tokenizer should be False.
"""Initialize model and tokenizer from given path.
Args:
model_path (str): Path to model.
tokenizer_path (str): Path to tokenizer.
use_fast_tokenizer (bool): Whether to use fast tokenizer.
local_rank (int): Local rank of current process.
world_size (int): World size of current process.
Note:
If the model is converted from new version of transformers,
use_fast_tokenizer should be True.
If using depodaca/llama-xb-hf, use_fast_tokenizer should be False.
"""
if not _is_transformers_available:
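A hypothetical call, assuming `init_model` returns the `(model, tokenizer)` pair its docstring describes; the paths are placeholders:

```python
# Hypothetical usage; paths are placeholders and the return value
# is assumed from the docstring, not shown in this diff.
model, tokenizer = init_model(
    model_path='path/to/llama-hf',
    tokenizer_path='path/to/llama-hf',
    use_fast_tokenizer=True,  # True for checkpoints converted by recent transformers
)
```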
@@ -86,6 +97,18 @@ def main(
seed: int = 0,
use_fast_tokenizer: bool = True,
):
"""Start chat session with given model.
Args:
model_path (str): Path to model.
tokenizer_path (str): Path to tokenizer.
max_new_tokens (int): Maximum number of tokens to generate.
temperature (float): Temperature for sampling.
top_p (float): Top p for sampling.
seed (int): Random seed.
use_fast_tokenizer (bool): Whether to use fast tokenizer.
"""
torch.manual_seed(seed)
local_rank = int(os.getenv('LOCAL_RANK', '0'))
......
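A hypothetical direct invocation of the chat entry point with the parameters its docstring documents; every value below is a placeholder:

```python
# Hypothetical call; all values are placeholders.
main(
    model_path='path/to/model',
    tokenizer_path='path/to/tokenizer',
    max_new_tokens=128,
    temperature=0.8,
    top_p=0.95,
    seed=0,
    use_fast_tokenizer=True,
)
```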
@@ -8,6 +8,8 @@ from transformers.generation.streamers import BaseStreamer
def get_utils(model):
"""Get utils by model type."""
name = model.__class__.__name__
if name == 'InferenceEngine':
name = model.module.__class__.__name__
@@ -21,7 +23,7 @@ def get_utils(model):
class DecodeOutputStreamer(BaseStreamer):
"""Output generated tokens to shell."""
"""Default streamer for HuggingFace models."""
def __init__(self, tokenizer, skip_prompt=True) -> None:
super().__init__()
@@ -35,6 +37,8 @@ class DecodeOutputStreamer(BaseStreamer):
self.decode = self._decode_fallback
def _decode_with_raw_id(self, value):
"""Convert token ids to tokens and decode."""
tok = self.tokenizer._convert_id_to_token(value)
if tok.startswith('▁'): # sentencepiece
space = ' '
@@ -48,12 +52,16 @@ class DecodeOutputStreamer(BaseStreamer):
return space + tok
def _decode_fallback(self, value):
"""Fallback decoder for non-fast tokenizer."""
tok = self.tokenizer.decode(value,
skip_special_tokens=False,
clean_up_tokenization_spaces=False)
return tok + ' '
def put(self, value):
"""Callback function to decode token and output to stdout."""
if self.gen_len == 0 and self.skip_prompt:
pass
else:
@@ -63,11 +71,13 @@ class DecodeOutputStreamer(BaseStreamer):
self.gen_len += 1
def end(self):
"""Callback function to finish generation."""
print('\n')
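The `▁` check in `_decode_with_raw_id` handles sentencepiece's convention of encoding a leading space as the character U+2581 (`▁`) rather than a real space. The convention in isolation, with a token string that assumes a LLaMA-style vocabulary:

```python
# Sketch of the sentencepiece convention; '▁world' is how a LLaMA-style
# vocabulary stores the piece for ' world'.
tok = '▁world'
if tok.startswith('▁'):  # U+2581, not an ASCII space
    tok = ' ' + tok[1:]  # restore the real leading space
assert tok == ' world'
```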
class InternLMStreamer(DecodeOutputStreamer):
"""Output generated tokens to shell."""
"""Streamer for InternLM."""
def __init__(self, tokenizer, skip_prompt=True) -> None:
BaseStreamer.__init__(self)  # skip DecodeOutputStreamer's init, run only the base one
@@ -77,6 +87,8 @@ class InternLMStreamer(DecodeOutputStreamer):
self.hex_regex = re.compile(r'^<0x([0-9ABCDEF]+)>$')
def decode(self, value):
"""Decode generated tokens for InternLM."""
tok = self.tokenizer.decode(value)
if res := self.hex_regex.match(tok):
tok = chr(int(res.group(1), 16))
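Tokens of the form `<0xNN>` are sentencepiece byte-fallback pieces; the regex captures the hex digits so the byte can be mapped back to a character. The mapping in isolation:

```python
import re

# Same pattern as above, applied to a byte-fallback token for '\n'.
hex_regex = re.compile(r'^<0x([0-9ABCDEF]+)>$')
res = hex_regex.match('<0x0A>')
assert res and chr(int(res.group(1), 16)) == '\n'
```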
@@ -87,29 +99,37 @@ class InternLMStreamer(DecodeOutputStreamer):
class BaseDecorator:
"""Base decorator for decorating prompt and extracting generated output."""
@classmethod
def decorate(cls, prompt):
"""Abstract method for adding Add special tokens to prompt."""
return prompt
@classmethod
def extract(cls, gen_out):
"""Abstract methods for extract generated output from model output."""
return gen_out
class InternLMDecorator(BaseDecorator):
"""Decorator for InternLM."""
regex = re.compile(r'<\|Bot\|>:(.*)')
@classmethod
def decorate(cls, prompt):
"""Decorate prompt for InternLM."""
return f'<|User|>:{prompt}<eoh>'
@classmethod
def extract(cls, gen_out):
"""Extract generated tokens for InternLM."""
return cls.regex.search(gen_out).group(1)
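A round-trip sketch using the classes above: the prompt is wrapped in InternLM's chat markup before generation, and the reply is extracted from the raw output afterwards. The model output string is fabricated for illustration:

```python
# Round trip through InternLMDecorator; raw_output is fabricated.
prompt = InternLMDecorator.decorate('hi')
assert prompt == '<|User|>:hi<eoh>'

raw_output = '<|User|>:hi<eoh>\n<|Bot|>:hello!'
assert InternLMDecorator.extract(raw_output) == 'hello!'
```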
class InternLMStoppingCriteria(StoppingCriteria):
"""Stopping criteria for HF version of InternLM."""
def __call__(self, input_ids, *args, **kwargs) -> bool:
return input_ids[0, -1] in [2, 103028]
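The criterion fires once the last generated token is id 2 (eos) or 103028 (assumed here to be InternLM's end-of-answer token). In isolation, and hypothetically wired into HuggingFace generation:

```python
import torch

criteria = InternLMStoppingCriteria()
assert criteria(torch.tensor([[5, 9, 2]]))      # last token is eos: stop
assert not criteria(torch.tensor([[5, 9, 7]]))  # keep generating

# Hypothetical wiring (model and input_ids elided):
# from transformers import StoppingCriteriaList
# model.generate(input_ids, stopping_criteria=StoppingCriteriaList([criteria]))
```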