env.py 1.4 KB
Newer Older
wangkx1's avatar
init  
wangkx1 committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
# env.py
import os
from typing import Dict, Any

class Environment:
    """管理环境变量和系统配置"""
    
    def __init__(self, config: Dict[str, Any]):
        self.config = config
        self.base_env = config.get('base', {})
        self.device_id = self.base_env.get('common', {}).get('device_id', 3)
        
    def setup(self, model_env_vars: Dict[str, str] = None):
        """设置环境变量"""
        
        # 设置基础环境变量
        base_env_vars = self.base_env.get('env_vars', {})
        for key, value in base_env_vars.items():
            if key in os.environ:
                # 如果环境变量已存在,追加
                os.environ[key] = f"{value}:{os.environ[key]}"
            else:
                os.environ[key] = value
        
        # 设置模型特定的环境变量
        if model_env_vars:
            for key, value in model_env_vars.items():
                os.environ[key] = str(value)
        
        # 设置设备可见性
        os.environ['HIP_VISIBLE_DEVICES'] = str(self.device_id)
        
    def get_result_dir(self) -> str:
        """获取结果目录"""
        result_dir = self.base_env.get('common', {}).get('result_dir', '../result-v1')
        
        # 确保目录存在
        if not os.path.exists(result_dir):
            os.makedirs(result_dir)
            print(f"创建目录: {result_dir}")
            
        return result_dir