utils.py 3.92 KB
Newer Older
WRH's avatar
WRH committed
1
2
3
4
5
6
7
8
9
10
# Copyright (c) OpenMMLab. All rights reserved.

import codecs
import re

from transformers import (PreTrainedTokenizerFast, StoppingCriteria,
                          StoppingCriteriaList)
from transformers.generation.streamers import BaseStreamer


def get_utils(model):
    """Pick the helper classes that match ``model``'s architecture.

    A wrapper whose class is named ``InferenceEngine`` (presumably a DeepSpeed
    inference engine -- confirm) is unwrapped through its ``module``
    attribute before the architecture name is inspected.

    Args:
        model: The (possibly wrapped) causal language model.

    Returns:
        tuple: ``(decorator, streamer, stop_criteria)``. ``InternLMForCausalLM``
        models get :class:`InternLMDecorator`, :class:`InternLMStreamer` and a
        ``StoppingCriteriaList`` holding one :class:`InternLMStoppingCriteria`;
        every other model gets :class:`BaseDecorator`,
        :class:`DecodeOutputStreamer` and ``None``.
    """
    if model.__class__.__name__ == 'InferenceEngine':
        inner = model.module
    else:
        inner = model
    arch = inner.__class__.__name__

    if arch != 'InternLMForCausalLM':
        return BaseDecorator, DecodeOutputStreamer, None

    criteria = StoppingCriteriaList([InternLMStoppingCriteria()])
    return InternLMDecorator, InternLMStreamer, criteria


class DecodeOutputStreamer(BaseStreamer):
    """Default streamer for HuggingFace models.

    The first :meth:`put` call is treated as the prompt (and skipped when
    ``skip_prompt`` is True); every later call decodes the first element of
    ``value`` and prints it to stdout immediately. :meth:`end` terminates the
    output and resets the streamer so it can be reused for another
    generation.

    Args:
        tokenizer: Tokenizer used to turn token ids back into text. Fast
            tokenizers are decoded piece-by-piece from the raw vocabulary;
            any other tokenizer goes through ``tokenizer.decode``.
        skip_prompt (bool): Do not echo the first :meth:`put` call.
            Defaults to True.
    """

    def __init__(self, tokenizer, skip_prompt=True) -> None:
        super().__init__()
        self.tokenizer = tokenizer
        self.skip_prompt = skip_prompt
        self.gen_len = 0  # number of put() calls since the last reset
        if isinstance(tokenizer, PreTrainedTokenizerFast):
            self.decode = self._decode_with_raw_id
            # Matches SentencePiece byte-fallback pieces such as ``<0x0A>``.
            self.hex_regex = re.compile(r'^<0x([0-9ABCDEF]+)>$')
            # Byte-fallback emits one ``<0xNN>`` piece per UTF-8 byte, so a
            # multi-byte character (e.g. CJK) arrives split across several
            # tokens. The incremental decoder buffers bytes until the
            # character is complete instead of printing each byte as a
            # separate (wrong) code point.
            self._byte_decoder = codecs.getincrementaldecoder('utf-8')(
                errors='replace')
        else:
            self.decode = self._decode_fallback

    def _decode_with_raw_id(self, value):
        """Convert a token id to its vocabulary piece and render it as text.

        Args:
            value: A single token id.

        Returns:
            str: Text to print for this token. It is empty while a
            multi-byte UTF-8 character is still incomplete.
        """
        tok = self.tokenizer._convert_id_to_token(value)
        if tok.startswith('▁'):  # sentencepiece word-boundary marker
            space = ' '
            tok = tok[1:]
        else:
            space = ''

        match = self.hex_regex.match(tok)
        code = int(match.group(1), 16) if match else None
        if code is not None and code <= 0xFF:
            # Raw byte piece: yields '' until the UTF-8 sequence completes.
            return space + self._byte_decoder.decode(bytes([code]))

        if code is not None:
            tok = chr(code)
        elif tok == '</s>':
            tok = '\n'
        # Any other token ends a pending byte sequence. Flush it (incomplete
        # bytes become U+FFFD) so it is not silently dropped.
        pending = self._byte_decoder.decode(b'', final=True)
        return pending + space + tok

    def _decode_fallback(self, value):
        """Fallback decoder for non-fast tokenizers.

        Decodes ``value`` with special tokens kept and no space clean-up,
        then appends a single space.
        """
        tok = self.tokenizer.decode(value,
                                    skip_special_tokens=False,
                                    clean_up_tokenization_spaces=False)
        return tok + ' '

    def put(self, value):
        """Decode ``value[0]`` and print it, unless it is a skipped prompt."""
        if self.gen_len > 0 or not self.skip_prompt:
            print(self.decode(value[0]), end='', flush=True)
        self.gen_len += 1

    def end(self):
        """Flush pending output, end the line and reset for reuse."""
        # Not every instance owns a byte decoder: the non-fast-tokenizer path
        # and subclasses that skip this __init__ never create one.
        decoder = getattr(self, '_byte_decoder', None)
        tail = decoder.decode(b'', final=True) if decoder is not None else ''
        print(tail + '\n')
        # Without this reset, a reused streamer would treat the next prompt
        # as generated output (and try to decode the whole prompt tensor).
        self.gen_len = 0


class InternLMStreamer(DecodeOutputStreamer):
    """Streamer for InternLM.

    Each generated token is decoded with ``tokenizer.decode``; the ``</s>``
    and ``<eoa>`` markers are printed as a newline. Printing, prompt skipping
    and :meth:`end` are inherited from :class:`DecodeOutputStreamer`.
    """

    def __init__(self, tokenizer, skip_prompt=True) -> None:
        # DecodeOutputStreamer.__init__ is deliberately bypassed: it would
        # bind an instance attribute ``self.decode`` that shadows the
        # ``decode`` method defined below. Initialise the BaseStreamer part
        # of *this* instance directly instead.
        BaseStreamer.__init__(self)
        self.tokenizer = tokenizer
        self.skip_prompt = skip_prompt
        self.gen_len = 0
        # Matches SentencePiece byte-fallback pieces such as ``<0x0A>``.
        self.hex_regex = re.compile(r'^<0x([0-9ABCDEF]+)>$')

    def decode(self, value):
        """Decode one generated token for InternLM.

        Args:
            value: Token id(s) accepted by ``tokenizer.decode``.

        Returns:
            str: The decoded text. A ``<0xNN>`` piece becomes ``chr(NN)``,
            and ``</s>`` / ``<eoa>`` become a newline.
        """
        tok = self.tokenizer.decode(value)
        if res := self.hex_regex.match(tok):
            tok = chr(int(res.group(1), 16))
        if tok in ('</s>', '<eoa>'):
            tok = '\n'
        return tok


class BaseDecorator:
    """Identity decorator: prompts and generated text pass through unchanged.

    Model-specific decorators subclass this and override :meth:`decorate`
    and/or :meth:`extract`.
    """

    @classmethod
    def decorate(cls, prompt):
        """Wrap ``prompt`` in model-specific special tokens.

        The base implementation adds nothing and returns ``prompt`` as-is.
        """
        decorated = prompt
        return decorated

    @classmethod
    def extract(cls, gen_out):
        """Extract the model's reply from the generated output.

        The base implementation returns ``gen_out`` as-is.
        """
        reply = gen_out
        return reply


class InternLMDecorator(BaseDecorator):
    """Decorator for InternLM chat prompts.

    Prompts are wrapped as ``<|User|>:{prompt}<eoh>``; replies are read back
    from the text that follows the ``<|Bot|>:`` marker.
    """

    # ``.`` does not match newlines here, so only the rest of the line that
    # contains the marker is captured.
    regex = re.compile(r'<\|Bot\|>:(.*)')

    @classmethod
    def decorate(cls, prompt):
        """Wrap ``prompt`` in InternLM's user-turn tokens."""
        return f'<|User|>:{prompt}<eoh>'

    @classmethod
    def extract(cls, gen_out):
        """Return the text after the first ``<|Bot|>:`` marker in ``gen_out``.

        If ``gen_out`` contains no marker, it is returned unchanged (the base
        behaviour) instead of failing with ``AttributeError`` on ``None``.
        """
        match = cls.regex.search(gen_out)
        if match is None:
            return super().extract(gen_out)
        return match.group(1)


class InternLMStoppingCriteria(StoppingCriteria):
    """Stopping criteria for HF version of InternLM.

    Generation stops once the most recent token of the first sequence in the
    batch is one of ``stop_token_ids``. Other batch rows are not inspected.

    Args:
        stop_token_ids (Iterable[int]): Token ids that end generation.
            Defaults to ``(2, 103028)``, presumably InternLM's ``</s>`` and
            ``<eoa>`` ids -- verify against the tokenizer vocabulary.
    """

    def __init__(self, stop_token_ids=(2, 103028)) -> None:
        super().__init__()
        self.stop_token_ids = frozenset(int(i) for i in stop_token_ids)

    def __call__(self, input_ids, *args, **kwargs) -> bool:
        # Only batch row 0 is checked; ``int`` turns the 0-d element into a
        # plain Python int for the set lookup.
        return int(input_ids[0, -1]) in self.stop_token_ids