push_to_hub.py 6.87 KB
Newer Older
wanglch's avatar
wanglch committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
# Copyright (c) Alibaba, Inc. and its affiliates.

import concurrent.futures
import os
import shutil
from multiprocessing import Manager, Process, Value

from swift.utils.logger import get_logger
from .api import HubApi
from .constants import DEFAULT_REPOSITORY_REVISION, ModelVisibility

# Module-level logger shared by every helper in this file.
logger = get_logger()

# Process pool (at most 8 workers) that runs uploads submitted via `push_to_hub_async`.
_executor = concurrent.futures.ProcessPoolExecutor(max_workers=8)
# Registries used by `push_to_hub_in_queue` / `wait_for_done`, all keyed by queue name.
_queues = dict()  # queue name -> manager queue holding pending upload kwargs
_flags = dict()  # queue name -> shared boolean, set to True by the worker while it handles an item
_tasks = dict()  # queue name -> worker process running `submit_task`
# multiprocessing Manager, created on the first `push_to_hub_in_queue` call.
_manager = None


def _api_push_to_hub(repo_name,
                     output_dir,
                     token,
                     private=True,
                     commit_message='',
                     tag=None,
                     source_repo='',
                     ignore_file_pattern=None,
                     revision=DEFAULT_REPOSITORY_REVISION):
    """Log in to the hub with `token` and push `output_dir` to `repo_name` once.

    Any exception raised during login, upload or success logging is caught
    and logged; it is never propagated to the caller.

    Args:
        repo_name: The repo name for the modelhub repo (also passed as `chinese_name`)
        output_dir: The local directory to upload
        token: The user api token used for `HubApi.login`
        private: Push with private visibility when truthy, public otherwise
        commit_message: The commit message
        tag: The tag of this commit
        source_repo: The source repo (model id), forwarded as `original_model_id`
        ignore_file_pattern: The file pattern to be ignored in uploading
        revision: The branch to commit to
    Returns:
        True if the upload finished without raising, False otherwise.
    """
    try:
        hub = HubApi()
        hub.login(token)
        if private:
            visibility = ModelVisibility.PRIVATE
        else:
            visibility = ModelVisibility.PUBLIC
        hub.push_model(
            repo_name,
            output_dir,
            visibility=visibility,
            chinese_name=repo_name,
            commit_message=commit_message,
            tag=tag,
            original_model_id=source_repo,
            ignore_file_pattern=ignore_file_pattern,
            revision=revision)
        # Only the log text gets a placeholder; the push above used the raw message.
        if not commit_message:
            commit_message = 'No commit message'
        logger.info(f'Successfully upload the model to {repo_name} with message: {commit_message}')
    except Exception as e:
        logger.error(f'Error happens when uploading model {repo_name} with message: {commit_message}: {e}')
        return False
    return True


def push_to_hub(repo_name,
                output_dir,
                token=None,
                private=True,
                retry=3,
                commit_message='',
                tag=None,
                source_repo='',
                ignore_file_pattern=None,
                revision=DEFAULT_REPOSITORY_REVISION):
    """Upload a local checkpoint directory to the model hub, retrying on failure.

    Args:
        repo_name: The repo name for the modelhub repo
        output_dir: The local output_dir for the checkpoint; it must contain a
            `configuration.json`, `configuration.yaml` or `configuration.yml` entry
        token: The user api token, read from the `MODELSCOPE_API_TOKEN` environment
            variable if this argument is None
        private: If is a private repo, default True
        retry: Maximum number of upload attempts, default 3
        commit_message: The commit message
        tag: The tag of this commit
        source_repo: The source repo (model id) which this model comes from
        ignore_file_pattern: The file pattern to be ignored in uploading, read from the
            `UPLOAD_IGNORE_FILE_PATTERN` environment variable if this argument is None
        revision: The branch to commit to
    Returns:
        True as soon as one attempt succeeds, False if every attempt fails.
    """
    token = os.environ.get('MODELSCOPE_API_TOKEN') if token is None else token
    if ignore_file_pattern is None:
        ignore_file_pattern = os.environ.get('UPLOAD_IGNORE_FILE_PATTERN')

    assert repo_name is not None
    assert token is not None, 'Either pass in a token or to set `MODELSCOPE_API_TOKEN` in the environment variables.'
    assert os.path.isdir(output_dir)
    # At least one of the accepted configuration file names must be present.
    assert not {'configuration.json', 'configuration.yaml', 'configuration.yml'}.isdisjoint(os.listdir(output_dir))

    logger.info(f'Uploading {output_dir} to {repo_name} with message {commit_message}')
    # `any` stops at the first successful attempt, so at most `retry` uploads are tried.
    return any(
        _api_push_to_hub(repo_name, output_dir, token, private, commit_message, tag, source_repo,
                         ignore_file_pattern, revision) for _ in range(retry))


def push_to_hub_async(repo_name,
                      output_dir,
                      token=None,
                      private=True,
                      commit_message='',
                      tag=None,
                      source_repo='',
                      ignore_file_pattern=None,
                      revision=DEFAULT_REPOSITORY_REVISION):
    """Schedule a single upload attempt on the module-level process pool.

    The arguments are validated in the calling process; the upload itself is
    submitted to the executor and is attempted once (no retry).

    Args:
        repo_name: The repo name for the modelhub repo
        output_dir: The local output_dir for the checkpoint; it must contain a
            `configuration.json`, `configuration.yaml` or `configuration.yml` entry
        token: The user api token, read from the `MODELSCOPE_API_TOKEN` environment
            variable if this argument is None
        private: If is a private repo, default True
        commit_message: The commit message
        tag: The tag of this commit
        source_repo: The source repo (model id) which this model comes from
        ignore_file_pattern: The file pattern to be ignored in uploading, read from the
            `UPLOAD_IGNORE_FILE_PATTERN` environment variable if this argument is None
        revision: The branch to commit to
    Returns:
        The future returned by the executor; its result is the boolean
        outcome of the upload.
    """
    token = os.environ.get('MODELSCOPE_API_TOKEN') if token is None else token
    if ignore_file_pattern is None:
        ignore_file_pattern = os.environ.get('UPLOAD_IGNORE_FILE_PATTERN')

    assert repo_name is not None
    assert token is not None, 'Either pass in a token or to set `MODELSCOPE_API_TOKEN` in the environment variables.'
    assert os.path.isdir(output_dir)
    # At least one of the accepted configuration file names must be present.
    assert not {'configuration.json', 'configuration.yaml', 'configuration.yml'}.isdisjoint(os.listdir(output_dir))

    logger.info(f'Uploading {output_dir} to {repo_name} with message {commit_message}')
    upload_args = (repo_name, output_dir, token, private, commit_message, tag, source_repo, ignore_file_pattern,
                   revision)
    return _executor.submit(_api_push_to_hub, *upload_args)


def submit_task(q, b):
    """Worker loop: take upload requests from `q` and run them one at a time.

    Each queued item is a dict of `push_to_hub` keyword arguments plus two
    optional control keys that are removed before the call:
    `done` (truthy -> stop the loop) and `delete_dir` (truthy -> remove
    `output_dir` after the upload call returns).

    Args:
        q: Queue-like object whose `get()` returns the next request dict
        b: Shared flag object; `b.value` is False while waiting for an item
            and True while an item is being handled
    """
    while True:
        b.value = False
        item = q.get()
        # The request may carry the user's API token; mask it so the credential
        # never ends up in the log output.
        logger.info({key: ('***' if key == 'token' and value is not None else value) for key, value in item.items()})
        b.value = True
        if item.pop('done', False):
            break
        delete_dir = item.pop('delete_dir', False)
        output_dir = item.get('output_dir')
        try:
            # NOTE(review): the boolean result is ignored, so `output_dir` is removed
            # even when every upload attempt failed -- confirm this is intended.
            push_to_hub(**item)
            if delete_dir and os.path.exists(output_dir):
                shutil.rmtree(output_dir)
        except Exception as e:
            # Keep the worker alive; a bad request must not stop later uploads.
            logger.error(e)


class UploadStrategy:
    """Policies for `push_to_hub_in_queue` when the queue's worker is busy."""

    # Enqueue the new upload behind the running one.
    wait = 'wait'
    # Drop the new upload (an error is logged) while another one is running.
    cancel = 'cancel'


def push_to_hub_in_queue(queue_name, strategy=UploadStrategy.cancel, **kwargs):
    """Hand an upload request to the background worker of a named queue.

    The first call for a given `queue_name` creates the queue, its shared
    busy flag and a worker process running `submit_task`.

    Args:
        queue_name: Non-empty name identifying the queue
        strategy: What to do when the worker is busy. With
            `UploadStrategy.cancel` the request is dropped and an error is
            logged; with any other value it is enqueued.
        **kwargs: The request forwarded to the worker (`push_to_hub` keyword
            arguments, optionally `delete_dir`). A request with a truthy
            `done` entry is always enqueued and tells the worker to stop.
    """
    global _manager
    assert queue_name is not None and len(queue_name) > 0, 'Please specify a valid queue name!'
    if _manager is None:
        _manager = Manager()
    if queue_name not in _queues:
        pending = _queues[queue_name] = _manager.Queue()
        busy = _flags[queue_name] = Value('b', False)
        worker = Process(target=submit_task, args=(pending, busy))
        worker.start()
        _tasks[queue_name] = worker

    pending = _queues[queue_name]
    busy = _flags[queue_name]
    # Only ordinary uploads can be cancelled; the stop signal always goes through.
    if not kwargs.get('done', False) and busy.value and strategy == UploadStrategy.cancel:
        logger.error(f'Another uploading is running, '
                     f'this uploading with message {kwargs.get("commit_message")} will be canceled.')
        return
    pending.put(kwargs)


def wait_for_done(queue_name):
    """Block until the worker of `queue_name` exits, then forget the queue.

    Does nothing when no worker is registered under `queue_name`. The worker
    only exits after it receives a request with a truthy `done` entry, so
    that request has to be enqueued for this call to return.

    Args:
        queue_name: Name of the queue whose worker should be joined
    """
    worker = _tasks.pop(queue_name, None)
    if worker is not None:
        worker.join()
        # Drop the remaining per-queue state so the name can be reused.
        del _queues[queue_name]
        del _flags[queue_name]