# user_manager.py
import json
import os
import re
import uuid
5
6
import glob
import shutil
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
from aiohttp import web
from comfy.cli_args import args
from folder_paths import user_directory
from .app_settings import AppSettings

# Identifier given to the single implicit user; it has the same value as the
# "default" literal used by UserManager when multi-user mode is off.
default_user = "default"

# JSON file (inside the user directory) that persists the known users as a
# mapping of user id -> display name; only read in multi-user mode.
users_file = os.path.join(user_directory, "users.json")


class UserManager():
    """Manage user profiles and each user's private ``userdata`` directory.

    In multi-user mode (``args.multi_user``) the known users are loaded from
    ``users_file`` (a JSON object mapping user id -> display name); otherwise
    a single ``default_user`` is used.  ``add_routes`` exposes HTTP endpoints
    for listing/creating users and for listing, reading, writing, moving and
    deleting files inside the requesting user's directory.
    """

    def __init__(self):
        self.settings = AppSettings(self)
        if not os.path.exists(user_directory):
            # makedirs (rather than mkdir) also creates missing parent dirs.
            os.makedirs(user_directory, exist_ok=True)
            if not args.multi_user:
                print("****** User settings have been changed to be stored on the server instead of browser storage. ******")
                print("****** For multi-user setups add the --multi-user CLI argument to enable multiple user profiles. ******")

        if args.multi_user:
            if os.path.isfile(users_file):
                with open(users_file, encoding="utf-8") as f:
                    self.users = json.load(f)
            else:
                self.users = {}
        else:
            self.users = {default_user: default_user}

    def get_request_user_id(self, request):
        """Return the user id a request acts as.

        In multi-user mode the id is taken from the ``comfy-user`` header when
        present; otherwise ``default_user`` is used.

        Raises:
            KeyError: if the resolved id is not a known user.
        """
        user = default_user
        if args.multi_user and "comfy-user" in request.headers:
            user = request.headers["comfy-user"]

        if user not in self.users:
            raise KeyError("Unknown user: " + user)

        return user

    def get_request_user_filepath(self, request, file, type="userdata", create_dir=True):
        """Resolve *file* inside the requesting user's directory.

        Args:
            request: incoming request, used to identify the user.
            file: path relative to the user's directory, or None to get the
                user's directory itself.
            type: root category; only ``"userdata"`` is supported.
            create_dir: when true, create the parent directory of the
                resolved path if it does not exist yet.

        Returns:
            The absolute path, or None if the user id or *file* would escape
            its containing directory.

        Raises:
            KeyError: for an unsupported *type* or an unknown user.
        """
        if type == "userdata":
            # Normalise to an absolute, trailing-slash-free path so that the
            # commonpath() containment checks below compare like with like
            # (commonpath also raises ValueError on absolute/relative mixes).
            root_dir = os.path.abspath(user_directory)
        else:
            raise KeyError("Unknown filepath type:" + type)

        user = self.get_request_user_id(request)
        path = user_root = os.path.abspath(os.path.join(root_dir, user))

        # prevent leaving /{type}
        if os.path.commonpath((root_dir, user_root)) != root_dir:
            return None

        if file is not None:
            # prevent leaving /{type}/{user}
            path = os.path.abspath(os.path.join(user_root, file))
            if os.path.commonpath((user_root, path)) != user_root:
                return None

        parent = os.path.split(path)[0]

        if create_dir and not os.path.exists(parent):
            os.makedirs(parent, exist_ok=True)

        return path

    def add_user(self, name):
        """Register a new user called *name* and persist the user list.

        The id is *name* with runs of characters outside ``[a-zA-Z0-9-_]``
        replaced by ``-``, followed by ``_`` and a random UUID4.

        Returns:
            The newly created user id.

        Raises:
            ValueError: if *name* is empty or whitespace only.
        """
        name = name.strip()
        if not name:
            raise ValueError("username not provided")
        user_id = re.sub(r"[^a-zA-Z0-9-_]+", '-', name)
        user_id = user_id + "_" + str(uuid.uuid4())

        self.users[user_id] = name

        with open(users_file, "w", encoding="utf-8") as f:
            json.dump(self.users, f)

        return user_id

    def add_routes(self, routes):
        """Register the user and userdata HTTP routes (plus settings routes)."""
        self.settings.add_routes(routes)

        @routes.get("/users")
        async def get_users(request):
            if args.multi_user:
                return web.json_response({"storage": "server", "users": self.users})
            else:
                user_dir = self.get_request_user_filepath(request, None, create_dir=False)
                return web.json_response({
                    "storage": "server",
                    # Reports whether the single user's directory already
                    # exists on disk.
                    "migrated": os.path.exists(user_dir)
                })

        @routes.post("/users")
        async def post_users(request):
            body = await request.json()
            username = body["username"]
            if username in self.users.values():
                return web.json_response({"error": "Duplicate username."}, status=400)

            try:
                user_id = self.add_user(username)
            except ValueError as e:
                # add_user rejects blank names: that is a client error, so
                # answer 400 instead of letting it surface as a 500.
                return web.json_response({"error": str(e)}, status=400)
            return web.json_response(user_id)

        @routes.get("/userdata")
        async def listuserdata(request):
            """List files under ``?dir=`` inside the user's directory.

            Query parameters: ``dir`` (required); ``recurse=true`` to include
            files in subdirectories; ``split=true`` to return each entry as
            ``[relpath, *relpath.split(os.sep)]``.
            """
            directory = request.rel_url.query.get('dir', '')
            if not directory:
                return web.Response(status=400)

            path = self.get_request_user_filepath(request, directory)
            if not path:
                return web.Response(status=403)

            if not os.path.exists(path):
                return web.Response(status=404)

            recurse = request.rel_url.query.get('recurse', '').lower() == "true"
            # With recursive=False glob treats '**' like '*', so '**/*' would
            # only match entries one level down and skip the top-level files.
            # Use a plain '*' for the non-recursive listing.
            pattern = '**/*' if recurse else '*'
            results = glob.glob(os.path.join(
                glob.escape(path), pattern), recursive=recurse)
            results = [os.path.relpath(x, path) for x in results if os.path.isfile(x)]

            split_path = request.rel_url.query.get('split', '').lower() == "true"
            if split_path:
                results = [[x] + x.split(os.sep) for x in results]

            return web.json_response(results)

        def get_user_data_path(request, check_exists = False, param = "file"):
            """Resolve route parameter *param* to a path in the user's directory.

            Returns the path as a ``str`` on success; otherwise a
            ``web.Response`` the handler should return unchanged: 400 if the
            parameter is missing/empty, 403 if it escapes the user's
            directory, 404 if *check_exists* and the path does not exist.
            """
            file = request.match_info.get(param, None)
            if not file:
                return web.Response(status=400)

            path = self.get_request_user_filepath(request, file)
            if not path:
                return web.Response(status=403)

            if check_exists and not os.path.exists(path):
                return web.Response(status=404)

            return path

        def wants_overwrite(request):
            """Return False only for ``?overwrite=false``; the parameter is optional."""
            return request.query.get("overwrite", "true") != "false"

        @routes.get("/userdata/{file}")
        async def getuserdata(request):
            path = get_user_data_path(request, check_exists=True)
            if not isinstance(path, str):
                return path

            return web.FileResponse(path)

        @routes.post("/userdata/{file}")
        async def post_userdata(request):
            path = get_user_data_path(request)
            if not isinstance(path, str):
                return path

            if not wants_overwrite(request) and os.path.exists(path):
                return web.Response(status=409)

            body = await request.read()

            with open(path, "wb") as f:
                f.write(body)

            # Respond with the stored path relative to the user's directory.
            resp = os.path.relpath(path, self.get_request_user_filepath(request, None))
            return web.json_response(resp)

        @routes.delete("/userdata/{file}")
        async def delete_userdata(request):
            path = get_user_data_path(request, check_exists=True)
            if not isinstance(path, str):
                return path

            os.remove(path)

            return web.Response(status=204)

        @routes.post("/userdata/{file}/move/{dest}")
        async def move_userdata(request):
            source = get_user_data_path(request, check_exists=True)
            if not isinstance(source, str):
                return source

            dest = get_user_data_path(request, check_exists=False, param="dest")
            # Must test dest (not source): an error Response here has to be
            # returned before it reaches os.path.exists / shutil.move.
            if not isinstance(dest, str):
                return dest

            if not wants_overwrite(request) and os.path.exists(dest):
                return web.Response(status=409)

            print(f"moving '{source}' -> '{dest}'")
            shutil.move(source, dest)

            # Respond with the new path relative to the user's directory.
            resp = os.path.relpath(dest, self.get_request_user_filepath(request, None))
            return web.json_response(resp)