"official/vision/configs/video_classification.py" did not exist on "4a8191d6327a8b0bb86a52c74fd544eba51a3a3b"
shape.py 2.25 KB
Newer Older
Sugon_ldc's avatar
Sugon_ldc committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import os.path as osp


def check_gdal() -> bool:
    """Report whether the GDAL Python bindings can be imported.

    The top-level ``gdal`` module is tried first; if that import fails,
    ``osgeo.gdal`` is tried instead.

    Returns:
        True if either import succeeds, False if both raise ImportError.
    """
    try:
        import gdal  # noqa: F401  (imported only to probe availability)
    except ImportError:
        # Only a failed import should trigger the fallback; anything else
        # (e.g. KeyboardInterrupt) must propagate instead of being swallowed.
        try:
            from osgeo import gdal  # noqa: F401
        except ImportError:
            return False
    return True


# Module-wide flag: set to True only after gdal, osr and ogr were imported.
IMPORT_STATE = False
if check_gdal():
    try:
        # Bindings exposed as top-level modules.
        import gdal
        import osr
        import ogr
    except ImportError:
        # Otherwise take all three from the ``osgeo`` package. Only a failed
        # import triggers this fallback; other errors are not swallowed.
        from osgeo import gdal, osr, ogr
    IMPORT_STATE = True


# Save a raster's polygonized first band as a shapefile.
def save_shp(shp_path: str, tif_path: str, ignore_index: int=0) -> str:
    """Polygonize band 1 of a raster and write the result to a shapefile.

    The raster at ``tif_path`` is opened with GDAL, its first band is
    polygonized (using that band's mask band as the mask) into a layer named
    "segmentation" with an integer field "DN", and every feature whose "DN"
    equals ``ignore_index`` is then deleted from the layer.

    Args:
        shp_path: Output path of the ESRI Shapefile. An existing file at this
            path is removed first.
        tif_path: Path of the raster to polygonize.
        ignore_index: "DN" value whose polygons are dropped from the output.

    Returns:
        The fixed message "Dataset creation successfully!".

    Raises:
        ImportError: If gdal, osr and ogr could not be imported.
        RuntimeError: If the raster cannot be opened or the output
            datasource cannot be created.
    """
    if not IMPORT_STATE:
        raise ImportError("can't import gdal, osr, ogr!")

    ds = gdal.Open(tif_path)
    # gdal.Open returns None on failure unless GDAL exceptions are enabled;
    # fail with a clear message rather than an AttributeError further down.
    if ds is None:
        raise RuntimeError("can't open raster file: {}".format(tif_path))

    dst_ds = None
    try:
        srcband = ds.GetRasterBand(1)
        maskband = srcband.GetMaskBand()
        gdal.SetConfigOption("GDAL_FILENAME_IS_UTF8", "YES")
        gdal.SetConfigOption("SHAPE_ENCODING", "UTF-8")
        ogr.RegisterAll()
        drv = ogr.GetDriverByName("ESRI Shapefile")
        if osp.exists(shp_path):
            os.remove(shp_path)
        dst_ds = drv.CreateDataSource(shp_path)
        if dst_ds is None:
            raise RuntimeError(
                "can't create shapefile: {}".format(shp_path))
        # Reuse the raster's projection as the layer's spatial reference.
        prosrs = osr.SpatialReference(wkt=ds.GetProjection())
        dst_layer = dst_ds.CreateLayer(
            "segmentation", geom_type=ogr.wkbPolygon, srs=prosrs)
        dst_fieldname = "DN"
        fd = ogr.FieldDefn(dst_fieldname, ogr.OFTInteger)
        dst_layer.CreateField(fd)
        # Field index 0 is the "DN" field created just above.
        gdal.Polygonize(srcband, maskband, dst_layer, 0, [])
        lyr = dst_ds.GetLayer()
        # Restrict iteration to the polygons that must be discarded.
        lyr.SetAttributeFilter("DN = '{}'".format(str(ignore_index)))
        for feature in lyr:
            lyr.DeleteFeature(feature.GetFID())
    finally:
        # Release the output datasource and the raster even when a step
        # above raised.
        if dst_ds is not None:
            dst_ds.Destroy()
        ds = None
    return "Dataset creation successfully!"