attentions.py 1.63 KB
Newer Older
chenzk's avatar
v1.0  
chenzk committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
import torch
from allamo.configuration import AllamoConfiguration
from allamo.logging import logger

class AttentionVersion:
    """
    Tracks which attention implementation the model should use.

    Versions:
    0 - eager (plain tensor ops)
    1 - SDPA (torch.nn.functional.scaled_dot_product_attention)
    2 - FA2 (Flash Attention 2)

    Attributes:
        version: the selected implementation, one of the codes above.
        flash_attn_2_supports_window_size: whether sliding-window attention is
            treated as available; only ever True while version == 2.
    """

    EAGER = 0
    SDPA = 1
    FLASH_ATTN_2 = 2

    def __init__(self):
        # Default to SDPA when this torch build provides it, otherwise eager;
        # Flash Attention 2 is opt-in via configure() / enable_flash_attn_2().
        self.disable_flash_attn_2()

    def configure(self, config: "AllamoConfiguration"):
        """
        Select the implementation named by ``config.attention_implementation``.

        Recognised values are 'flash_attention_2', 'sdpa' and 'eager'. A falsy
        value leaves the current selection untouched. An unrecognised value
        also leaves it untouched, but logs a warning so that a typo in the
        configuration does not silently fall back to the default.

        Args:
            config: any object exposing an ``attention_implementation`` attribute.
        """
        implementation = config.attention_implementation
        if not implementation:
            return
        selectors = {
            'flash_attention_2': self.enable_flash_attn_2,
            'sdpa': self.disable_flash_attn_2,
            'eager': self.force_eager,
        }
        select = selectors.get(implementation)
        if select is None:
            logger.warning(
                "Unknown attention_implementation %r; keeping attention version %s",
                implementation,
                self.version,
            )
            return
        select()

    def disable_flash_attn_2(self):
        """Use SDPA if torch provides it, otherwise fall back to eager attention."""
        has_sdpa = hasattr(torch.nn.functional, 'scaled_dot_product_attention')
        self.version = self.SDPA if has_sdpa else self.EAGER
        self.flash_attn_2_supports_window_size = False

    def enable_flash_attn_2(self):
        """Select Flash Attention 2 and mark sliding-window support as available."""
        # NOTE(review): this does not verify that the flash_attn package is
        # importable; a missing install will only surface where FA2 is used.
        self.version = self.FLASH_ATTN_2
        self.flash_attn_2_supports_window_size = True

    def force_eager(self):
        """Select eager attention even if SDPA is available."""
        self.version = self.EAGER
        self.flash_attn_2_supports_window_size = False

    def log_version(self, sliding_window):
        """
        Log which attention implementation is in use.

        Args:
            sliding_window: truthy when the model uses a sliding window; it is
                only reported for Flash Attention 2 with window support.

        Raises:
            ValueError: if ``version`` is not one of the known codes.
        """
        if self.version == self.FLASH_ATTN_2:
            logger.info("Using Flash Attention 2")
            if self.flash_attn_2_supports_window_size and sliding_window:
                logger.info("Using sliding window")
        elif self.version == self.SDPA:
            logger.info("Using scaled_dot_product_attention")
        elif self.version == self.EAGER:
            # Eager attention works but is slow; surface it at warning level
            # instead of embedding "WARNING" in an info-level message.
            logger.warning("Using slow (eager) attention")
        else:
            raise ValueError(f"Unsupported attention version: {self.version!r}")
    
# Shared module-level instance; starts on SDPA (or eager if SDPA is unavailable)
# and can be switched later through its configure() method.
attention_version: AttentionVersion = AttentionVersion()