training_context.py 1.21 KB
Newer Older
chenzk's avatar
v1.0  
chenzk committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
import os
import uuid
from dataclasses import dataclass

@dataclass
class TrainingContext:
    """Per-process bookkeeping for a (possibly distributed) training run.

    The three dataclass fields describe the parallel layout; judging by their
    names they are presumably the data-, tensor- and pipeline-parallel
    degrees (TODO confirm against the callers).  Everything else is populated
    in ``__post_init__`` from the process environment or with initial values,
    and is therefore not part of the generated ``__init__``/``__repr__``/
    ``__eq__``.

    Raises:
        AssertionError: if the resulting layout is inconsistent with
            ``WORLD_SIZE`` (see ``_validate``).
        ValueError: if one of the environment variables read in
            ``__post_init__`` is set to something ``int()`` cannot parse.
    """

    # Values below 1 mean "derive from world_size" (see _validate).
    dp: int = -1
    # tp and pp are clamped up to 1 when a smaller value is given.
    tp: int = 1
    pp: int = 1

    def __post_init__(self):
        """Read rank/world-size from the environment and reset run state."""
        # These variables are read with single-process fallbacks, so the
        # class is usable without a distributed launcher having set them.
        self.world_size = int(os.environ.get('WORLD_SIZE', 1))
        self.local_world_size = int(os.environ.get('LOCAL_WORLD_SIZE', 1))
        self.rank = int(os.environ.get('RANK', 0))
        self.local_rank = int(os.environ.get('LOCAL_RANK', 0))

        # A fresh random identifier per instance; training_uuid starts out
        # identical to run_uuid.
        self.run_uuid = str(uuid.uuid4())
        self.training_uuid = self.run_uuid
        self.iter_num = 0
        # Initial "best so far" losses start at 1e2.
        self.best_train_loss = 1e2
        self.best_val_loss = 1e2
        self.processed_tokens = 0

        self._validate()

    def _validate(self):
        """Normalise dp/tp/pp and check them against ``world_size``.

        ``pp`` and ``tp`` below 1 are reset to 1; ``dp`` below 1 is replaced
        by ``world_size // (tp * pp)``.

        Raises:
            AssertionError: if ``dp`` or ``tp`` is not positive, if
                ``pp != 1``, or if ``dp * tp * pp != world_size``.
        """
        if self.pp < 1:
            self.pp = 1
        if self.tp < 1:
            self.tp = 1
        if self.dp < 1:
            self.dp = self.world_size // (self.tp * self.pp)

        # Explicit raises instead of `assert` statements: asserts are removed
        # when Python runs with -O, which would let an inconsistent layout
        # through unchecked.  AssertionError is kept as the exception type so
        # existing callers that catch it keep working.
        if not self.dp > 0:
            # Only reachable when the derived dp came out below 1, i.e.
            # world_size is smaller than tp * pp.
            raise AssertionError(
                f"dp({self.dp}) must be > 0: "
                f"world_size({self.world_size}) < tp({self.tp}) * pp({self.pp})"
            )
        if not self.tp > 0:
            raise AssertionError(f"tp({self.tp}) must be > 0")
        if not self.pp == 1:
            raise AssertionError(f"pp({self.pp}) > 1 is not supported")
        if not self.dp * self.tp * self.pp == self.world_size:
            raise AssertionError(
                f"dp({self.dp}) * tp({self.tp}) * pp({self.pp}) != world_size({self.world_size})"
            )

    @property
    def master_process(self):
        """Return True when this process has rank 0."""
        return self.rank == 0